diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala index 905c1c905e01..84085d061847 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreator.scala @@ -58,9 +58,8 @@ class AstCreator( fileContent.foreach(fileNode.content(_)) val namespaceBlock = globalNamespaceBlock() methodAstParentStack.push(namespaceBlock) - val astForFakeMethod = - astInFakeMethod(namespaceBlock.fullName, parserResult.filename, parserResult.ast) - val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(astForFakeMethod)) + val astForFakeMethod = astInFakeMethod(namespaceBlock.fullName, parserResult.filename, parserResult.ast) + val ast = Ast(fileNode).withChild(Ast(namespaceBlock).withChild(astForFakeMethod)) Ast.storeInDiffGraph(ast, diffGraph) scope.createVariableReferenceLinks(diffGraph, parserResult.filename) diffGraph @@ -125,4 +124,5 @@ class AstCreator( case _ => shortenCode(new String(code)).stripLineEnd } } + } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala index cc2fef462721..3bc33608af91 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstCreatorHelper.scala @@ -3,7 +3,10 @@ package io.joern.swiftsrc2cpg.astcreation import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.* import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines import io.joern.x2cpg.{Ast, ValidationMode} -import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, EvaluationStrategies, PropertyNames} +import io.joern.x2cpg.datastructures.Stack.* +import io.joern.x2cpg.datastructures.VariableScopeManager +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.* import org.apache.commons.lang3.StringUtils object AstCreatorHelper { @@ -147,21 +150,63 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As otherElements(indexOfGuardStmt).asInstanceOf[CodeBlockItemSyntax].item.asInstanceOf[GuardStmtSyntax] val elementsAfterGuard = otherElements.slice(indexOfGuardStmt + 1, otherElements.size) - val code = this.code(guardStmt) - val ifNode = controlStructureNode(guardStmt, ControlStructureTypes.IF, code) - val conditionAst = astForNode(guardStmt.conditions) + val code = this.code(guardStmt) + val ifNode = controlStructureNode(guardStmt, ControlStructureTypes.IF, code) + + // Apply optional binding desugaring for guard let + // Create the block that will hold the unwrapped variables (blockNode argument is only used for location info) + val thenBlockNode = + if (elementsAfterGuard.nonEmpty) blockNode(elementsAfterGuard.head) else blockNode(guardStmt) + + val (conditionAst, unwrapAsts) = handleOptionalBindingConditions( + guardStmt.conditions.children, + onAllSimple = simpleBindings => { + val bindingInfos = collectBindingInfos(simpleBindings) + val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwraps = buildUnwrapAssignments(bindingInfos) + (condAst, unwraps) + }, + onMixed = (simpleBindings, tupleBindings) => { + val bindingInfos = collectBindingInfos(simpleBindings) + val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + (condAst, unwraps) + }, + onPartial = (simpleBindings, tupleBindings, otherConditions) => { + val bindingInfos = collectBindingInfos(simpleBindings) + val condAst = buildOptionalBindingCondition(guardStmt, bindingInfos, otherConditions) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwraps = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + (condAst, unwraps) + }, + onStandard = () => { + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val condAst = astForNode(guardStmt.conditions) + (condAst, List.empty) + } + ) + + val allThenChildren = unwrapAsts ++ astsForBlockElements(elementsAfterGuard) ++ deferElementsAstsOrdered + + // Closing the scope opened at the handleOptionalBindingConditions handler + scope.popScope() + localAstParentStack.pop() - val thenAst = astsForBlockElements(elementsAfterGuard) ++ deferElementsAstsOrdered match { + val thenAst = allThenChildren match { case Nil => - blockAst(blockNode(guardStmt), List.empty) - case blockElement :: Nil => + blockAst(thenBlockNode, List.empty) + case blockElement :: Nil if unwrapAsts.isEmpty => blockElement case blockChildren => - val block = blockNode(elementsAfterGuard.head) - blockAst(block, blockChildren) + blockAst(thenBlockNode, blockChildren) } val elseAst = astForNode(guardStmt.body) - setOrderExplicitly(elseAst, 3) val ifAst = ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) astsForBlockElements(elementsBeforeGuard) :+ ifAst @@ -516,4 +561,296 @@ trait AstCreatorHelper(implicit withSchemaValidation: ValidationMode) { this: As } } + /** Checks if a pattern is tuple-like (direct TuplePatternSyntax or ExpressionPatternSyntax wrapping TupleExprSyntax). + * Used to detect tuple patterns in optional binding conditions. + */ + protected def isTupleLikePattern(pattern: PatternSyntax): Boolean = pattern match { + case _: TuplePatternSyntax => true + case ep: ExpressionPatternSyntax if ep.expression.isInstanceOf[TupleExprSyntax] => true + case _ => false + } + + /** Information about an optional binding for desugaring if-let/while-let constructs. + * + * Contains the CPG-level names for optional binding desugaring. The localName is the unwrapped variable name that + * appears in the then/body block. The tmpName (when present) holds the optional value from the initializer and is + * used for the nil check in the condition block, ensuring single evaluation of the initializer expression. + * + * @param localName + * CPG name for the unwrapped variable in the then/body block (e.g., "a" in `if let a = foo()`) + * @param tmpName + * CPG name for temporary holding the optional value in condition (e.g., "0" in the nil check) + * @param binding + * The source-level OptionalBindingConditionSyntax node + * @param isWildcard + * True if the pattern is a wildcard (e.g., `if let _ = foo()`), requiring generated name + */ + protected case class BindingInfo( + localName: String, + tmpName: Option[String], + binding: OptionalBindingConditionSyntax, + isWildcard: Boolean + ) + + protected def collectBindingInfos(bindings: Seq[OptionalBindingConditionSyntax]): Seq[BindingInfo] = { + bindings.map { binding => + val (localName, isWildcard) = binding.pattern match { + case ident: IdentifierPatternSyntax => (code(ident.identifier), false) + case _ => (scopeLocalUniqueName("wildcard"), true) + } + val tmpName = binding.initializer.map(_ => scopeLocalUniqueName("tmp")) + BindingInfo(localName, tmpName, binding, isWildcard) + } + } + + protected def handleOptionalBindingConditions[T]( + conditions: Iterable[ConditionElementSyntax], + onAllSimple: Seq[OptionalBindingConditionSyntax] => T, + onMixed: (Seq[OptionalBindingConditionSyntax], Seq[OptionalBindingConditionSyntax]) => T, + onPartial: ( + Seq[OptionalBindingConditionSyntax], + Seq[OptionalBindingConditionSyntax], + Seq[ConditionElementSyntax] + ) => T, + onStandard: () => T + ): T = { + val conditionsSeq = conditions.toSeq + val optionalBindings = conditionsSeq.collect { + case condElem if condElem.condition.isInstanceOf[OptionalBindingConditionSyntax] => + condElem.condition.asInstanceOf[OptionalBindingConditionSyntax] + } + + val (simpleBindings, tupleBindings) = optionalBindings.partition(binding => !isTupleLikePattern(binding.pattern)) + val otherConditions = + conditionsSeq.filterNot(condElem => condElem.condition.isInstanceOf[OptionalBindingConditionSyntax]) + + if (simpleBindings.isEmpty) { + // No simple bindings to desugar + onStandard() + } else if (otherConditions.isEmpty && tupleBindings.isEmpty) { + // All conditions are simple optional bindings + onAllSimple(simpleBindings) + } else if (otherConditions.isEmpty) { + // All conditions are optional bindings (mixed simple + tuple) + onMixed(simpleBindings, tupleBindings) + } else { + // Partial: simple bindings + other conditions (and maybe tuple bindings) + onPartial(simpleBindings, tupleBindings, otherConditions) + } + } + + /** Combines multiple nil check ASTs with logical AND operator. If only one check exists, returns it directly. + * + * @param node + * The control structure node for creating operator nodes + * @param nilCheckAsts + * The nil check ASTs to combine + * @return + * Single AST representing all checks combined with && + */ + private def combineNilChecksWithAnd(node: SwiftNode, nilCheckAsts: Seq[Ast]): Ast = { + if (nilCheckAsts.size == 1) { + nilCheckAsts.head + } else { + nilCheckAsts.reduce { (left, right) => + val leftCode = left.root.map(codeOf).getOrElse("") + val rightCode = right.root.map(codeOf).getOrElse("") + val andCode = s"($leftCode) && ($rightCode)" + val andCallNode = createStaticCallNode(node, andCode, Operators.logicalAnd, Operators.logicalAnd, Defines.Bool) + callAst(andCallNode, List(left, right)) + } + } + } + + /** Builds the condition AST for optional binding constructs (if-let/while-let). If any binding has an initializer, + * creates a block with temp variable assignments and nil checks. Otherwise creates a simple combined nil check. + * + * @param node + * The control structure node (IfExprSyntax or WhileStmtSyntax) + * @param bindingInfos + * Information about each optional binding + * @param additionalConditions + * Additional condition elements to AND with the nil checks + * @return + * Condition AST (either block or direct nil check) + */ + protected def buildOptionalBindingCondition( + node: SwiftNode, + bindingInfos: Seq[BindingInfo], + additionalConditions: Seq[ConditionElementSyntax] = Seq.empty + ): Ast = { + val hasAnyInitializer = bindingInfos.exists(_.tmpName.isDefined) + + if (hasAnyInitializer) { + bindingInfos.foreach { info => + info.tmpName.foreach { tmpName => + val tmpLocalNode = localNode(info.binding, tmpName, tmpName, Defines.Any).order(0) + diffGraph.addEdge(localAstParentStack.head, tmpLocalNode, EdgeTypes.AST) + scope.addVariable(tmpName, tmpLocalNode, Defines.Any, VariableScopeManager.ScopeType.BlockScope) + } + } + + val condBlockNode = blockNode(node) + scope.pushNewBlockScope(condBlockNode) + localAstParentStack.push(condBlockNode) + + val nilCheckAsts = bindingInfos.map { info => + info.tmpName match { + case Some(tmpName) => + val tmpIdentNode = identifierNode(info.binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdentNode, Defines.Any, EvaluationStrategies.BY_REFERENCE) + val initAst = astForNode(info.binding.initializer.get.value) + val assignAst = createAssignmentCallAst( + info.binding, + Ast(tmpIdentNode), + initAst, + s"$tmpName = ${code(info.binding.initializer.get.value)}" + ) + + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"($tmpName = ${code(info.binding.initializer.get.value)}) != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(assignAst, Ast(nilNode))) + + case None => + val patternAst = astForNode(info.binding.pattern) + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"${info.localName} != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(patternAst, Ast(nilNode))) + } + } + + val additionalConditionAsts = additionalConditions.map(condElem => astForNode(condElem.condition)) + val allChecks = nilCheckAsts ++ additionalConditionAsts + val combinedCheckAst = combineNilChecksWithAnd(node, allChecks) + + scope.popScope() + localAstParentStack.pop() + + blockAst(condBlockNode, List(combinedCheckAst)) + } else { + val nilCheckAsts = bindingInfos.map { info => + val patternAst = astForNode(info.binding.pattern) + val nilNode = literalNode(info.binding, "nil", Option(Defines.Nil)) + val checkCallNode = createStaticCallNode( + info.binding, + s"${info.localName} != nil", + Operators.notEquals, + Operators.notEquals, + Defines.Bool + ) + callAst(checkCallNode, List(patternAst, Ast(nilNode))) + } + + val additionalConditionAsts = additionalConditions.map(condElem => astForNode(condElem.condition)) + val allChecks = nilCheckAsts ++ additionalConditionAsts + combineNilChecksWithAnd(node, allChecks) + } + } + + /** Builds the body AST with optional unwrapping assignments prepended. For bindings with initializers, creates locals + * and unwrapping assignments in the body block. + * + * @param bodyNode + * The body syntax node for creating the block + * @param bodyStatements + * The statements to include in the body + * @param bindingInfos + * Information about each optional binding + * @return + * Body AST with unwrapping assignments (if needed) followed by original body statements + */ + protected def buildBodyWithUnwrapping( + bodyNode: SwiftNode, + bodyStatements: Iterable[SwiftNode], + bindingInfos: Seq[BindingInfo] + ): Ast = { + val bindingsWithInitializer = bindingInfos.filter(info => info.tmpName.isDefined && !info.isWildcard) + + if (bindingsWithInitializer.nonEmpty) { + val bodyBlockNode = blockNode(bodyNode) + scope.pushNewBlockScope(bodyBlockNode) + localAstParentStack.push(bodyBlockNode) + + val unwrapAsts = bindingsWithInitializer.map { info => + val binding = info.binding + val tmpName = info.tmpName.get + + val typeFullName = + binding.typeAnnotation.fold(Defines.Any)(typeAnn => AstCreatorHelper.cleanType(code(typeAnn.`type`))) + val kind = code(binding.bindingSpecifier) + val scopeType = + if (kind == "let") VariableScopeManager.ScopeType.BlockScope + else VariableScopeManager.ScopeType.MethodScope + + val tpeFromTypeMap = fullnameProvider.typeFullname(binding.pattern).getOrElse(typeFullName) + registerType(tpeFromTypeMap) + val localNode_ = localNode(binding, info.localName, info.localName, tpeFromTypeMap).order(0) + scope.addVariable(info.localName, localNode_, tpeFromTypeMap, scopeType) + diffGraph.addEdge(bodyBlockNode, localNode_, EdgeTypes.AST) + + val localIdentNode = identifierNode(binding, info.localName, info.localName, tpeFromTypeMap) + scope.addVariableReference(info.localName, localIdentNode, tpeFromTypeMap, EvaluationStrategies.BY_REFERENCE) + + val tmpIdent = identifierNode(binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdent, Defines.Any, EvaluationStrategies.BY_REFERENCE) + createAssignmentCallAst(binding, Ast(localIdentNode), Ast(tmpIdent), s"${info.localName} = $tmpName") + } + + val bodyAsts = bodyStatements.map(astForNode).toList + + scope.popScope() + localAstParentStack.pop() + + if (unwrapAsts.isEmpty && bodyAsts.isEmpty) { + // Empty body - return empty block without creating a nested block + Ast(bodyBlockNode) + } else { + blockAst(bodyBlockNode, unwrapAsts.toList ++ bodyAsts) + } + } else { + astForNode(bodyNode) + } + } + + protected def buildUnwrapAssignments(bindingInfos: Seq[BindingInfo]): List[Ast] = { + val bindingsWithInitializer = bindingInfos.filter(info => info.tmpName.isDefined && !info.isWildcard) + + bindingsWithInitializer.map { info => + val binding = info.binding + val tmpName = info.tmpName.get + + val typeFullName = + binding.typeAnnotation.fold(Defines.Any)(typeAnn => AstCreatorHelper.cleanType(code(typeAnn.`type`))) + val kind = code(binding.bindingSpecifier) + val scopeType = + if (kind == "let") VariableScopeManager.ScopeType.BlockScope + else VariableScopeManager.ScopeType.MethodScope + + val tpeFromTypeMap = fullnameProvider.typeFullname(binding.pattern).getOrElse(typeFullName) + registerType(tpeFromTypeMap) + val localNode_ = localNode(binding, info.localName, info.localName, tpeFromTypeMap).order(0) + scope.addVariable(info.localName, localNode_, tpeFromTypeMap, scopeType) + diffGraph.addEdge(localAstParentStack.head, localNode_, EdgeTypes.AST) + + val localIdentNode = identifierNode(binding, info.localName, info.localName, tpeFromTypeMap) + scope.addVariableReference(info.localName, localIdentNode, tpeFromTypeMap, EvaluationStrategies.BY_REFERENCE) + + val tmpIdent = identifierNode(binding, tmpName, tmpName, Defines.Any) + scope.addVariableReference(tmpName, tmpIdent, Defines.Any, EvaluationStrategies.BY_REFERENCE) + createAssignmentCallAst(binding, Ast(localIdentNode), Ast(tmpIdent), s"${info.localName} = $tmpName") + }.toList + } + } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala index 5dfc7efa6f4b..6a2b86912731 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForExprSyntaxCreator.scala @@ -8,7 +8,7 @@ import io.joern.x2cpg.datastructures.VariableScopeManager import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewCall} +import io.shiftleft.codepropertygraph.generated.nodes.{ExpressionNew, NewCall, NewControlStructure} import scala.annotation.{tailrec, unused} @@ -498,15 +498,114 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForIfExprSyntax(node: IfExprSyntax): Ast = { - val code = this.code(node) - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) - val conditionAstRaw = astForNode(node.conditions) - val conditionAst = conditionAstRaw.root match { - case Some(_) => conditionAstRaw - case None => blockAst(blockNode(node.conditions), List.empty) - } - val thenAst = astForNode(node.body) - val elseAst = node.elseBody.map(astForNode) + val code = this.code(node) + val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) + + handleOptionalBindingConditions( + node.conditions.children, + onAllSimple = simpleBindings => astForIfLetExprSyntax(node, ifNode, simpleBindings, node.body, node.elseBody), + onMixed = (simpleBindings, tupleBindings) => + astForIfLetExprSyntaxMixed(node, ifNode, simpleBindings, tupleBindings, node.body, node.elseBody), + onPartial = (simpleBindings, tupleBindings, otherConditions) => + astForIfLetExprSyntaxPartial( + node, + ifNode, + simpleBindings, + tupleBindings, + otherConditions, + node.body, + node.elseBody + ), + onStandard = () => { + val conditionAst = astForNode(node.conditions) + val thenAst = astForNode(node.body) + val elseAst = node.elseBody.map(astForNode) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) + } + ) + } + + /** Handles Swift optional binding (if-let) constructs. + * + * De-sugars `if let baz = foo() { body }` into: + * + * Condition: { (0 = foo()) != nil } + * + * Then block: { let baz = 0; body } + * + * For multiple bindings `if let a = foo(), let b = bar() { body }`: + * + * Condition: { (0 = foo()) != nil && (1 = bar()) != nil } + * + * Then block: { a = 0; b = 1; body } + * + * For mixed cases with/without initializers `if let a = foo(), let b { body }`: + * + * Condition: { (0 = foo()) != nil && b != nil } + * + * Then block: { a = 0; body } + */ + private def astForIfLetExprSyntax( + node: IfExprSyntax, + ifNode: NewControlStructure, + optionalBindings: Seq[OptionalBindingConditionSyntax], + thenBody: CodeBlockSyntax, + elseBody: Option[IfExprSyntax | CodeBlockSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(optionalBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val thenAst = buildBodyWithUnwrapping(thenBody, thenBody.statements.children, bindingInfos) + val elseAst = elseBody.map(astForNode) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) + } + + /** Handles mixed optional binding constructs with both simple and tuple patterns. + * + * De-sugars `if let a = foo(), let (b, c) = bar() { body }` into: + * + * Condition: { (0 = foo()) != nil } + * + * Then block: { let a = 0; let (b, c) = bar(); body } + */ + private def astForIfLetExprSyntaxMixed( + node: IfExprSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + thenBody: CodeBlockSyntax, + elseBody: Option[IfExprSyntax | CodeBlockSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val thenAst = buildBodyWithUnwrapping(thenBody, tupleBindings ++ thenBody.statements.children, bindingInfos) + val elseAst = elseBody.map(astForNode) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) + } + + /** Handles partial optional binding desugaring with other conditions. + * + * De-sugars `if let a = foo(), #unavailable(...) { body }` into: + * + * Condition: { ((0 = foo()) != nil) && #unavailable(...) } + * + * Then block: { let a = 0; body } + */ + private def astForIfLetExprSyntaxPartial( + node: IfExprSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + otherConditions: Seq[ConditionElementSyntax], + thenBody: CodeBlockSyntax, + elseBody: Option[IfExprSyntax | CodeBlockSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos, otherConditions) + val thenAst = buildBodyWithUnwrapping(thenBody, tupleBindings ++ thenBody.statements.children, bindingInfos) + val elseAst = elseBody.map(astForNode) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, elseAst) } @@ -723,7 +822,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * identifier/field-identifier nodes so the resulting AST can be safely used as an argument without node-sharing * issues. */ - protected def createFieldAccessChain(baseName: String, fields: List[String], node: SwiftNode): Ast = { + private def createFieldAccessChain(baseName: String, fields: List[String], node: SwiftNode): Ast = { val baseAst = Ast(identifierNode(node, baseName)) fields.foldLeft(baseAst) { (accAst, field) => createFieldAccessCallAst(node, accAst, fieldIdentifierNode(node, field, field)) @@ -826,7 +925,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { */ /** Creates an instanceOf check for an IsTypePatternSyntax against a subject field access. */ - protected def astForIsTypePatternInTupleContext( + private def astForIsTypePatternInTupleContext( isType: IsTypePatternSyntax, subjectAst: Ast, subjectCode: String, @@ -843,7 +942,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Creates an equality check for an expression pattern against a subject field access. */ - protected def astForExpressionPatternInTupleContext( + private def astForExpressionPatternInTupleContext( ep: ExpressionPatternSyntax, subjectAst: Ast, subjectCode: String, @@ -856,7 +955,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Creates a variable binding assignment for a pattern element against a subject field access. */ - protected def astForBindingInTupleContext( + private def astForBindingInTupleContext( varName: String, subjectAst: Ast, subjectCode: String, @@ -903,7 +1002,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Determines whether an expression inside a tuple represents a binding (let/var pattern). */ - protected def isBindingExpression(expr: ExprSyntax): Boolean = expr match { + private def isBindingExpression(expr: ExprSyntax): Boolean = expr match { case p: PatternExprSyntax => p.pattern match { case _: ValueBindingPatternSyntax => true @@ -914,7 +1013,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } /** Dispatches a PatternSyntax inside a tuple context to the appropriate de-sugaring. */ - protected def astsForPatternInTupleContext( + private def astsForPatternInTupleContext( pattern: PatternSyntax, subjectAst: Ast, subjectCode: String, @@ -967,7 +1066,7 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { * - `DeclReferenceExprSyntax` (`a` in `case let (a, b):`) * - `PatternExprSyntax(ValueBindingPatternSyntax(IdentifierPatternSyntax))` (`var a` in `case (var a, var b):`) */ - protected def extractBindingName(expr: ExprSyntax): String = { + private def extractBindingName(expr: ExprSyntax): String = { expr match { case d: DeclReferenceExprSyntax => code(d) case p: PatternExprSyntax => @@ -1056,7 +1155,6 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { val condIdentNode = identifierNode(node, subjectTmpName, subjectTmpName, Defines.Tuple) scope.addVariableReference(subjectTmpName, condIdentNode, Defines.Tuple, EvaluationStrategies.BY_REFERENCE) val condAst = Ast(condIdentNode) - setOrderExplicitly(condAst, 1) val switchBlockNode = blockNode(node).order(2) scope.pushNewBlockScope(switchBlockNode) @@ -1073,15 +1171,10 @@ trait AstForExprSyntaxCreator(implicit withSchemaValidation: ValidationMode) { blockAst(outerBlockNode, List(subjectAssignAst, switchAstResult)) } else { - // The semantics of switch statement children is partially defined by their order value. - // The blockAst must have order == 2. Only to avoid collision we set switchExpressionAst to 1 - // because the semantics of it is already indicated via the condition edge. - val switchNode = controlStructureNode(node, ControlStructureTypes.SWITCH, code(node)) - + val switchNode = controlStructureNode(node, ControlStructureTypes.SWITCH, code(node)) val switchExpressionAst = astForNode(node.subject) - setOrderExplicitly(switchExpressionAst, 1) - val blockNode_ = blockNode(node).order(2) + val blockNode_ = blockNode(node) scope.pushNewBlockScope(blockNode_) localAstParentStack.push(blockNode_) val casesAsts = cases.flatMap(astsForSwitchCase(_, None)) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala index c5aca3b607ae..ffc667163a5d 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstForStmtSyntaxCreator.scala @@ -2,16 +2,11 @@ package io.joern.swiftsrc2cpg.astcreation import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.* import io.joern.x2cpg -import io.joern.x2cpg.Ast -import io.joern.x2cpg.ValidationMode +import io.joern.x2cpg.{Ast, ValidationMode} import io.joern.x2cpg.datastructures.Stack.* import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines -import io.shiftleft.codepropertygraph.generated.ControlStructureTypes -import io.shiftleft.codepropertygraph.generated.nodes.NewJumpLabel -import io.shiftleft.codepropertygraph.generated.DispatchTypes -import io.shiftleft.codepropertygraph.generated.EdgeTypes -import io.shiftleft.codepropertygraph.generated.EvaluationStrategies -import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.* +import io.shiftleft.codepropertygraph.generated.nodes.{NewControlStructure, NewJumpLabel} import scala.annotation.unused @@ -561,11 +556,112 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForGuardStmtSyntax(node: GuardStmtSyntax): Ast = { - val code = this.code(node) - val ifNode = controlStructureNode(node, ControlStructureTypes.IF, code) - val conditionAst = astForNode(node.conditions) - val thenAst = blockAst(blockNode(node), List.empty) - val elseAst = astForNode(node.body) + // This is already handled in AstCreatorHelper.astsForBlockElements + Ast() + } + + /** Handles Swift optional binding (guard-let) constructs. + * + * De-sugars `guard let x = foo() else { exit }` into: + * + * Condition: { (0 = foo()) != nil } + * + * Then block: { let x = 0 } + * + * Else block: { exit } + * + * For multiple bindings `guard let a = foo(), let b = bar() else { exit }`: + * + * Condition: { ((0 = foo()) != nil) && ((1 = bar()) != nil) } + * + * Then block: { a = 0; b = 1 } + * + * For mixed cases with/without initializers `guard let a = foo(), let b else { exit }`: + * + * Condition: { ((0 = foo()) != nil) && (b != nil) } + * + * Then block: { a = 0 } + */ + private def astForGuardLetStmtSyntax( + node: GuardStmtSyntax, + ifNode: NewControlStructure, + optionalBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(optionalBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + + // For guard, the then block contains the unwrapped bindings in the parent scope + val thenBlockNode = blockNode(node) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwrapAsts = buildUnwrapAssignments(bindingInfos) + scope.popScope() + localAstParentStack.pop() + val thenAst = blockAst(thenBlockNode, unwrapAsts) + + val elseAst = astForNode(node.body) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) + } + + /** Handles mixed optional binding constructs with both simple and tuple patterns. + * + * De-sugars `guard let a = foo(), let (b, c) = bar() else { exit }` into: + * + * Condition: { (0 = foo()) != nil } + * + * Then block: { let a = 0; let (b, c) = bar() } + */ + private def astForGuardLetStmtSyntaxMixed( + node: GuardStmtSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + + val thenBlockNode = blockNode(node) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwrapAsts = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + scope.popScope() + localAstParentStack.pop() + val thenAst = blockAst(thenBlockNode, unwrapAsts) + + val elseAst = astForNode(node.body) + + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) + } + + /** Handles partial optional binding desugaring with other conditions. + * + * De-sugars `guard let a = foo(), someCondition else { exit }` into: + * + * Condition: { ((0 = foo()) != nil) && someCondition } + * + * Then block: { let a = 0 } + */ + private def astForGuardLetStmtSyntaxPartial( + node: GuardStmtSyntax, + ifNode: NewControlStructure, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + otherConditions: Seq[ConditionElementSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos, otherConditions) + + val thenBlockNode = blockNode(node) + scope.pushNewBlockScope(thenBlockNode) + localAstParentStack.push(thenBlockNode) + val unwrapAsts = buildUnwrapAssignments(bindingInfos) ++ tupleBindings.map(astForNode) + scope.popScope() + localAstParentStack.pop() + val thenAst = blockAst(thenBlockNode, unwrapAsts) + + val elseAst = astForNode(node.body) + ifThenElseAst(ifNode, Option(conditionAst), thenAst, Option(elseAst)) } @@ -585,11 +681,9 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { private def astForRepeatStmtSyntax(node: RepeatStmtSyntax): Ast = { val code = this.code(node) // In Swift, a repeat-while loop is semantically the same as a C do-while loop - val doNode = controlStructureNode(node, ControlStructureTypes.DO, code) - val conditionAst = astForNode(node.condition) - val bodyAst = astForNode(node.body) - setOrderExplicitly(conditionAst, 1) - setOrderExplicitly(bodyAst, 2) + val doNode = controlStructureNode(node, ControlStructureTypes.DO, code) + val conditionAst = astForNode(node.condition) + val bodyAst = astForNode(node.body) val astWithChildren = controlStructureAst(doNode, Option(conditionAst), Seq(bodyAst), placeConditionLast = true) bodyAst.root match { case Some(bodyRoot) => astWithChildren.withDoBodyEdge(doNode, bodyRoot) @@ -624,13 +718,113 @@ trait AstForStmtSyntaxCreator(implicit withSchemaValidation: ValidationMode) { } private def astForWhileStmtSyntax(node: WhileStmtSyntax): Ast = { - val code = this.code(node) - val conditionAst = astForNode(node.conditions) - val bodyAst = astForNode(node.body) + val code = this.code(node) + + handleOptionalBindingConditions( + node.conditions.children, + onAllSimple = simpleBindings => astForWhileLetStmtSyntax(node, simpleBindings), + onMixed = (simpleBindings, tupleBindings) => astForWhileLetStmtSyntaxMixed(node, simpleBindings, tupleBindings), + onPartial = (simpleBindings, tupleBindings, otherConditions) => + astForWhileLetStmtSyntaxPartial(node, simpleBindings, tupleBindings, otherConditions), + onStandard = () => { + val conditionAst = astForNode(node.conditions) + val bodyAst = astForNode(node.body) + whileAst( + Option(conditionAst), + Seq(bodyAst), + code = Option(code), + lineNumber = line(node), + columnNumber = column(node) + ) + } + ) + } + + /** Handles Swift optional binding (while-let) constructs. + * + * De-sugars `while let item = iterator.next() { body }` into: + * + * Condition: { (0 = iterator.next()) != nil } + * + * Loop body: { let item = 0; body } + * + * For multiple bindings `while let a = foo(), let b = bar() { body }`: + * + * Condition: { ((0 = foo()) != nil) && ((1 = bar()) != nil) } + * + * Loop body: { a = 0; b = 1; body } + * + * For mixed cases with/without initializers `while let a = foo(), let b { body }`: + * + * Condition: { ((0 = foo()) != nil) && (b != nil) } + * + * Loop body: { a = 0; body } + */ + private def astForWhileLetStmtSyntax( + node: WhileStmtSyntax, + optionalBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(optionalBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val bodyAst = buildBodyWithUnwrapping(node.body, node.body.statements.children, bindingInfos) + + whileAst( + Option(conditionAst), + Seq(bodyAst), + code = Option(code(node)), + lineNumber = line(node), + columnNumber = column(node) + ) + } + + /** Handles mixed optional binding constructs with both simple and tuple patterns. + * + * De-sugars `while let a = foo(), let (b, c) = bar() { body }` into: + * + * Condition: { (0 = foo()) != nil } + * + * Loop body: { let a = 0; let (b, c) = bar(); body } + */ + private def astForWhileLetStmtSyntaxMixed( + node: WhileStmtSyntax, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos) + val bodyAst = buildBodyWithUnwrapping(node.body, tupleBindings ++ node.body.statements.children, bindingInfos) + whileAst( Option(conditionAst), Seq(bodyAst), - code = Option(code), + code = Option(code(node)), + lineNumber = line(node), + columnNumber = column(node) + ) + } + + /** Handles partial optional binding desugaring with other conditions. + * + * De-sugars `while let a = foo(), someCondition { body }` into: + * + * Condition: { ((0 = foo()) != nil) && someCondition } + * + * Loop body: { let a = 0; body } + */ + private def astForWhileLetStmtSyntaxPartial( + node: WhileStmtSyntax, + simpleBindings: Seq[OptionalBindingConditionSyntax], + tupleBindings: Seq[OptionalBindingConditionSyntax], + otherConditions: Seq[ConditionElementSyntax] + ): Ast = { + val bindingInfos = collectBindingInfos(simpleBindings) + val conditionAst = buildOptionalBindingCondition(node, bindingInfos, otherConditions) + val bodyAst = buildBodyWithUnwrapping(node.body, tupleBindings ++ node.body.statements.children, bindingInfos) + + whileAst( + Option(conditionAst), + Seq(bodyAst), + code = Option(code(node)), lineNumber = line(node), columnNumber = column(node) ) diff --git a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala index 41748c633180..08336eaf8827 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/main/scala/io/joern/swiftsrc2cpg/astcreation/AstNodeBuilder.scala @@ -1,7 +1,6 @@ package io.joern.swiftsrc2cpg.astcreation import io.joern.swiftsrc2cpg.parser.SwiftNodeSyntax.* -import io.joern.x2cpg import io.joern.x2cpg.frontendspecific.swiftsrc2cpg.Defines import io.joern.x2cpg.{Ast, ValidationMode} import io.shiftleft.codepropertygraph.generated.nodes.* @@ -9,10 +8,6 @@ import io.shiftleft.codepropertygraph.generated.{DispatchTypes, EdgeTypes, Opera trait AstNodeBuilder(implicit withSchemaValidation: ValidationMode) { this: AstCreator => - protected def setOrderExplicitly(ast: Ast, order: Int): Unit = { - ast.root.foreach { case expr: ExpressionNew => expr.order = order } - } - protected def codeOf(node: NewNode): String = node match { case astNodeNew: AstNodeNew => astNodeNew.code case _ => "" diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala index 854266e6f441..eabf474adc38 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/AvailabilityQueryTests.scala @@ -4,7 +4,6 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class AvailabilityQueryTests extends SwiftSrc2CpgSuite { @@ -40,14 +39,28 @@ class AvailabilityQueryTests extends SwiftSrc2CpgSuite { } "testAvailabilityQuery5b" in { - val cpg = code("if let _ = Optional(5), #unavailable(OSX 10.52, *) {}") + val cpg = code("if let _ = Optional(5), #unavailable(OSX 10.52, *) {}") + val List(methodBlock) = cpg.method.nameExact("").block.l + // After desugaring: 0 in global method block + val List(tmpLocal) = methodBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + val List(ifControlStructure) = cpg.controlStructure.isIf.l ifControlStructure.whenTrue.astChildren shouldBe empty - val List(assignment, unavailable) = ifControlStructure.condition.astChildren.isCall.l - assignment.code shouldBe "let _ = Optional(5)" - assignment.name shouldBe Operators.assignment + + // Condition is desugared: { (0 = Optional(5)) != nil && #unavailable(...) } + val List(condBlock) = ifControlStructure.condition.isBlock.l + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(nilCheck, unavailable) = andCheck.argument.isCall.l + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = Optional(5)) != nil" unavailable.code shouldBe "#unavailable(OSX 10.52, *)" unavailable.name shouldBe "#unavailable" + + val List(assignment) = nilCheck.argument.assignment.l + assignment.code shouldBe s"$tmpName = Optional(5)" + ifControlStructure.whenFalse shouldBe empty } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala index 52f99a6f5309..213782a02c4f 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTests.scala @@ -3,7 +3,6 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite - import io.shiftleft.codepropertygraph.generated.* import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* @@ -20,12 +19,11 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("noConditionNoElse").block.l val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 1 guardIf.code shouldBe "guard {} else {}" guardIf.controlStructureType shouldBe ControlStructureTypes.IF guardIf.condition.code.l shouldBe List("0") - guardIf.whenTrue.astChildren.code.l shouldBe empty - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe empty + guardIf.whenTrue.astChildren shouldBe empty + methodBlock.astChildren.isControlStructure.whenFalse.astChildren shouldBe empty } "testGuard2" in { @@ -38,16 +36,26 @@ class GuardTests extends SwiftSrc2CpgSuite { | } | print(i) | i = i + 1 - |} + |} |""".stripMargin) val List(whileBlock) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).whenTrue.l val List(guardIf) = whileBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 1 - guardIf.code should startWith("guard i % 2 == 0") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.code.l shouldBe List("i % 2 == 0") - guardIf.whenTrue.astChildren.code.l shouldBe List("print(i)", "i = i + 1") - whileBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List("i = i + 1", "continue") + + val List(condition) = guardIf.condition.l + condition.code shouldBe "i % 2 == 0" + + val List(thenPrint, thenAssign) = guardIf.whenTrue.astChildren.isCall.l + thenPrint.code shouldBe "print(i)" + thenAssign.name shouldBe Operators.assignment + thenAssign.code shouldBe "i = i + 1" + + val List(elseAssign) = guardIf.whenFalse.astChildren.isCall.l + elseAssign.name shouldBe Operators.assignment + elseAssign.code shouldBe "i = i + 1" + + val List(elseContinue) = guardIf.whenFalse.astChildren.isControlStructure.l + elseContinue.code shouldBe "continue" } "testGuard3" in { @@ -63,18 +71,22 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkOddEven").block.l val List(call) = methodBlock.astChildren.isCall.l - call.order shouldBe 1 call.code shouldBe "var number = 24" + val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 2 - guardIf.code should startWith("guard number % 2 == 0") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.code.l shouldBe List("number % 2 == 0") - guardIf.whenTrue.code.l shouldBe List("print(\"Even Number\")") - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List( - "print(\"Odd Number\")", - "return" - ) + + val List(condition) = guardIf.condition.l + condition.code shouldBe "number % 2 == 0" + + val List(thenPrint) = guardIf.whenTrue.isCall.l + thenPrint.code shouldBe "print(\"Even Number\")" + + val List(elsePrint) = guardIf.whenFalse.astChildren.isCall.l + elsePrint.code shouldBe "print(\"Odd Number\")" + + val List(elseReturn) = guardIf.whenFalse.astChildren.isReturn.l + elseReturn.code shouldBe "return" } "testGuard4" in { @@ -90,18 +102,23 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkJobEligibility").block.l val List(call) = methodBlock.astChildren.isCall.l - call.order shouldBe 1 call.code shouldBe "var age = 33" + val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 2 - guardIf.code should startWith("guard age >= 18, age <= 40") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.astChildren.code.l shouldBe List("age >= 18", "age <= 40") - guardIf.whenTrue.code.l shouldBe List("print(\"You are eligible for this job\")") - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List( - "print(\"Not Eligible for Job\")", - "return" - ) + + val List(cond1, cond2) = guardIf.condition.astChildren.l + cond1.code shouldBe "age >= 18" + cond2.code shouldBe "age <= 40" + + val List(thenPrint) = guardIf.whenTrue.isCall.l + thenPrint.code shouldBe "print(\"You are eligible for this job\")" + + val List(elsePrint) = guardIf.whenFalse.astChildren.isCall.l + elsePrint.code shouldBe "print(\"Not Eligible for Job\")" + + val List(elseReturn) = guardIf.whenFalse.astChildren.isReturn.l + elseReturn.code shouldBe "return" } "testGuard5" in { @@ -116,20 +133,47 @@ class GuardTests extends SwiftSrc2CpgSuite { |} |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("checkAge").block.l - methodBlock.local.name.l shouldBe List("myAge", "age") + // After desugaring: age, 0 in method block, myAge in then block + val List(tmpLocal, ageLocal) = methodBlock.astChildren.isLocal.l + ageLocal.name shouldBe "age" + val tmpName = tmpLocal.name + tmpName should startWith("") + val List(call) = methodBlock.astChildren.isCall.l - call.order shouldBe 1 call.code shouldBe "var age: Int? = 22" + val List(guardIf) = methodBlock.astChildren.isControlStructure.l - guardIf.order shouldBe 2 - guardIf.code should startWith("guard let myAge = age") guardIf.controlStructureType shouldBe ControlStructureTypes.IF - guardIf.condition.code.l shouldBe List("let myAge = age") - guardIf.whenTrue.code.l shouldBe List("print(\"My age is \\(myAge)\")") - methodBlock.astChildren.isControlStructure.whenFalse.astChildren.code.l shouldBe List( - "print(\"Age is undefined\")", - "return" - ) + + // Condition is desugared to block with temp assignment and nil check + val List(condBlock) = guardIf.condition.isBlock.l + val List(nilCheck) = condBlock.astChildren.isCall.l + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = age) != nil" + + val List(assignment) = nilCheck.argument.isCall.l + assignment.name shouldBe Operators.assignment + assignment.code shouldBe s"$tmpName = age" + + val List(nilLit) = nilCheck.argument.isLiteral.l + nilLit.code shouldBe "nil" + + // Then block should have myAge local + val List(thenBlock) = guardIf.whenTrue.isBlock.l + val List(myAgeLocal) = thenBlock.astChildren.isLocal.l + myAgeLocal.name shouldBe "myAge" + + val List(myAgeAssign, printCall) = thenBlock.astChildren.isCall.l + myAgeAssign.name shouldBe Operators.assignment + myAgeAssign.code shouldBe s"myAge = $tmpName" + printCall.code shouldBe "print(\"My age is \\(myAge)\")" + + val List(elsePrint) = guardIf.whenFalse.astChildren.isCall.l + elsePrint.name shouldBe "print" + elsePrint.code shouldBe "print(\"Age is undefined\")" + + val List(elseReturn) = guardIf.whenFalse.astChildren.isReturn.l + elseReturn.code shouldBe "return" } "testGuard6" in { @@ -151,23 +195,38 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("multipleLinear").block.l val List(callA, callB) = methodBlock.astChildren.isCall.l - callA.order shouldBe 1 callA.code shouldBe "var a = true" - callB.order shouldBe 2 callB.code shouldBe "var b = true" + val List(guardIfA) = methodBlock.astChildren.isControlStructure.l - guardIfA.order shouldBe 3 - guardIfA.code should startWith("guard a else") guardIfA.controlStructureType shouldBe ControlStructureTypes.IF - guardIfA.condition.code.l shouldBe List("a") - guardIfA.whenFalse.astChildren.code.l shouldBe List("print(\"else a\")", "return") - guardIfA.whenTrue.astChildren.isCall.code.l shouldBe List("print(\"a\")") + + val List(condA) = guardIfA.condition.l + condA.code shouldBe "a" + + val List(elsePrintA) = guardIfA.whenFalse.astChildren.isCall.l + elsePrintA.code shouldBe "print(\"else a\")" + + val List(elseReturnA) = guardIfA.whenFalse.astChildren.isReturn.l + elseReturnA.code shouldBe "return" + + val List(thenPrintA) = guardIfA.whenTrue.astChildren.isCall.l + thenPrintA.code shouldBe "print(\"a\")" + val List(guardIfB) = guardIfA.whenTrue.astChildren.isControlStructure.l - guardIfB.code should startWith("guard b else") guardIfB.controlStructureType shouldBe ControlStructureTypes.IF - guardIfB.condition.code.l shouldBe List("b") - guardIfB.whenTrue.code.l shouldBe List("print(\"b\")") - guardIfB.whenFalse.astChildren.code.l shouldBe List("print(\"else b\")", "return") + + val List(condB) = guardIfB.condition.l + condB.code shouldBe "b" + + val List(thenPrintB) = guardIfB.whenTrue.isCall.l + thenPrintB.code shouldBe "print(\"b\")" + + val List(elsePrintB) = guardIfB.whenFalse.astChildren.isCall.l + elsePrintB.code shouldBe "print(\"else b\")" + + val List(elseReturnB) = guardIfB.whenFalse.astChildren.isReturn.l + elseReturnB.code shouldBe "return" } "testGuard7" in { @@ -189,23 +248,351 @@ class GuardTests extends SwiftSrc2CpgSuite { |""".stripMargin) val List(methodBlock) = cpg.method.nameExact("multipleNested").block.l val List(callA, callB) = methodBlock.astChildren.isCall.l - callA.order shouldBe 1 callA.code shouldBe "var a = true" - callB.order shouldBe 2 callB.code shouldBe "var b = true" + val List(guardIfA) = methodBlock.astChildren.isControlStructure.l - guardIfA.order shouldBe 3 - guardIfA.code should startWith("guard a else") guardIfA.controlStructureType shouldBe ControlStructureTypes.IF - guardIfA.condition.code.l shouldBe List("a") - guardIfA.whenFalse.astChildren.isCall.code.l shouldBe List("print(\"else a\")") - guardIfA.whenTrue.isCall.code.l shouldBe List("print(\"a\")") + + val List(condA) = guardIfA.condition.l + condA.code shouldBe "a" + + val List(thenPrintA) = guardIfA.whenTrue.isCall.l + thenPrintA.code shouldBe "print(\"a\")" + + val List(elsePrintA) = guardIfA.whenFalse.astChildren.isCall.l + elsePrintA.code shouldBe "print(\"else a\")" + val List(guardIfB) = guardIfA.whenFalse.astChildren.isControlStructure.l - guardIfB.code should startWith("guard b else") guardIfB.controlStructureType shouldBe ControlStructureTypes.IF - guardIfB.condition.code.l shouldBe List("b") - guardIfB.whenTrue.astChildren.code.l shouldBe List("print(\"b\")", "return") - guardIfB.whenFalse.astChildren.code.l shouldBe List("print(\"else b\")", "return") + + val List(condB) = guardIfB.condition.l + condB.code shouldBe "b" + + val List(thenPrintB) = guardIfB.whenTrue.astChildren.isCall.l + thenPrintB.code shouldBe "print(\"b\")" + + val List(thenReturnB) = guardIfB.whenTrue.astChildren.isReturn.l + thenReturnB.code shouldBe "return" + + val List(nestedElsePrintB) = guardIfB.whenFalse.astChildren.isCall.l + nestedElsePrintB.code shouldBe "print(\"else b\")" + + val List(nestedElseReturnB) = guardIfB.whenFalse.astChildren.isReturn.l + nestedElseReturnB.code shouldBe "return" + } + + "testGuardLet" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | guard let value = optionalValue else { + | return + | } + | print(value) + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal) = methodBlock.astChildren.isLocal.l + val tmpName = tmpLocal.name + tmpName should startWith("") + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { (0 = optionalValue) != nil } + val List(condBlock) = guardIf.condition.isBlock.l + val List(condCheck) = condBlock.astChildren.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe s"($tmpName = optionalValue) != nil" + + val List(condAssign) = condCheck.argument.isCall.l + condAssign.name shouldBe Operators.assignment + condAssign.code shouldBe s"$tmpName = optionalValue" + + val List(tmpArg, optValueArg) = condAssign.argument.l + tmpArg.code shouldBe tmpName + optValueArg.code shouldBe "optionalValue" + + val List(nilLit) = condCheck.argument.isLiteral.l + nilLit.code shouldBe "nil" + + // Then block: { let value = 0; print(value) } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(valueLocal) = thenBlock.astChildren.isLocal.l + valueLocal.name shouldBe "value" + + val List(thenAssign, printCall) = thenBlock.astChildren.isCall.l + thenAssign.name shouldBe Operators.assignment + thenAssign.code shouldBe s"value = $tmpName" + printCall.code shouldBe "print(value)" + + val List(valueArg, tmpRefArg) = thenAssign.argument.l + valueArg.code shouldBe "value" + tmpRefArg.code shouldBe tmpName + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" + } + + "testGuardLetWithoutInitializer" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | guard let optionalValue else { + | return + | } + | print(optionalValue) + |} + |""".stripMargin) + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // For "guard let optionalValue" without explicit initializer: + // Condition: optionalValue != nil (direct comparison, no block) + val List(condCheck) = guardIf.condition.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe "optionalValue != nil" + condCheck.argument(1).code shouldBe "optionalValue" + condCheck.argument(2).code shouldBe "nil" + + // Then branch: print(optionalValue) but no new local or assignment for optionalValue + inside(guardIf.whenTrue.l) { case List(thenBlock) => + thenBlock.astChildren.isLocal.nameExact("optionalValue") shouldBe empty + val assignments = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + assignments.filter(_.argument(1).code == "optionalValue") shouldBe empty + } + } + + "testGuardLetMultipleBindings" in { + val cpg = code(""" + |func test() { + | guard let a = foo(), let b = bar() else { + | return + | } + | print(a, b) + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1 in method block + val List(tmp0Local, tmp1Local) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + val tmp1Name = tmp1Local.name + tmp1Name shouldBe "1" + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { ((0 = foo()) != nil) && ((1 = bar()) != nil) } + val List(condBlock) = guardIf.condition.isBlock.l + + val List(andCheck) = condBlock.astChildren.isCall.l + andCheck.name shouldBe Operators.logicalAnd + + val List(tmp0Check, tmp1Check) = andCheck.argument.isCall.l + tmp0Check.name shouldBe Operators.notEquals + tmp0Check.code shouldBe s"($tmp0Name = foo()) != nil" + tmp1Check.name shouldBe Operators.notEquals + tmp1Check.code shouldBe s"($tmp1Name = bar()) != nil" + + val List(tmp0Assign) = tmp0Check.argument.isCall.l + tmp0Assign.name shouldBe Operators.assignment + tmp0Assign.code shouldBe s"$tmp0Name = foo()" + + val List(tmp0Nil) = tmp0Check.argument.isLiteral.l + tmp0Nil.code shouldBe "nil" + + val List(tmp1Assign) = tmp1Check.argument.isCall.l + tmp1Assign.name shouldBe Operators.assignment + tmp1Assign.code shouldBe s"$tmp1Name = bar()" + + val List(tmp1Nil) = tmp1Check.argument.isLiteral.l + tmp1Nil.code shouldBe "nil" + + // Then block: { let a = 0; let b = 1; print(a, b) } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal, bLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + + val List(aAssignment, bAssignment, printCall) = thenBlock.astChildren.isCall.l + aAssignment.name shouldBe Operators.assignment + aAssignment.code shouldBe s"a = $tmp0Name" + bAssignment.name shouldBe Operators.assignment + bAssignment.code shouldBe s"b = $tmp1Name" + printCall.code shouldBe "print(a, b)" + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" + } + + "testGuardLetThreeBindings" in { + val cpg = code(""" + |func test() { + | guard let a = foo(), let b = bar(), let c = baz() else { + | return + | } + | print(a, b, c) + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1, 2 in method block + val List(tmp0Local, tmp1Local, tmp2Local) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + val tmp1Name = tmp1Local.name + tmp1Name shouldBe "1" + val tmp2Name = tmp2Local.name + tmp2Name shouldBe "2" + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { (((0 = foo()) != nil) && ((1 = bar()) != nil)) && ((2 = baz()) != nil) } + // The && operators are left-associated: (check0 && check1) && check2 + val List(condBlock) = guardIf.condition.isBlock.l + + val List(outerAndCheck) = condBlock.astChildren.isCall.l + outerAndCheck.name shouldBe Operators.logicalAnd + outerAndCheck.code shouldBe s"((($tmp0Name = foo()) != nil) && (($tmp1Name = bar()) != nil)) && (($tmp2Name = baz()) != nil)" + + val List(innerAndCheck, tmp2Check) = outerAndCheck.argument.isCall.l + innerAndCheck.name shouldBe Operators.logicalAnd + innerAndCheck.code shouldBe s"(($tmp0Name = foo()) != nil) && (($tmp1Name = bar()) != nil)" + + tmp2Check.name shouldBe Operators.notEquals + tmp2Check.code shouldBe s"($tmp2Name = baz()) != nil" + + val List(tmp0Check, tmp1Check) = innerAndCheck.argument.isCall.l + tmp0Check.name shouldBe Operators.notEquals + tmp0Check.code shouldBe s"($tmp0Name = foo()) != nil" + tmp1Check.name shouldBe Operators.notEquals + tmp1Check.code shouldBe s"($tmp1Name = bar()) != nil" + + // Then block: { let a = 0; let b = 1; let c = 2; print(a, b, c) } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal, bLocal, cLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + cLocal.name shouldBe "c" + + val List(aAssignment, bAssignment, cAssignment, printCall) = thenBlock.astChildren.isCall.l + aAssignment.name shouldBe Operators.assignment + aAssignment.code shouldBe s"a = $tmp0Name" + bAssignment.name shouldBe Operators.assignment + bAssignment.code shouldBe s"b = $tmp1Name" + cAssignment.name shouldBe Operators.assignment + cAssignment.code shouldBe s"c = $tmp2Name" + printCall.code shouldBe "print(a, b, c)" + } + + "testGuardLetMixedWithAndWithoutInitializer" in { + val cpg = code(""" + |func test(existing: Int?) { + | guard let a = foo(), let existing else { + | return + | } + | print(a, existing) + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { (0 = foo()) != nil && existing != nil } + val List(condBlock) = guardIf.condition.isBlock.l + + val List(andCheck) = condBlock.astChildren.isCall.l + andCheck.name shouldBe Operators.logicalAnd + + val List(tmpCheck, existingCheck) = andCheck.argument.isCall.l + tmpCheck.name shouldBe Operators.notEquals + tmpCheck.code shouldBe s"($tmpName = foo()) != nil" + existingCheck.name shouldBe Operators.notEquals + existingCheck.code shouldBe "existing != nil" + + val List(tmpAssign) = tmpCheck.argument.isCall.l + tmpAssign.name shouldBe Operators.assignment + tmpAssign.code shouldBe s"$tmpName = foo()" + + val List(tmpNil) = tmpCheck.argument.isLiteral.l + tmpNil.code shouldBe "nil" + + val List(existingIdent) = existingCheck.argument.isIdentifier.l + existingIdent.name shouldBe "existing" + existingIdent.code shouldBe "existing" + + val List(existingNil) = existingCheck.argument.isLiteral.l + existingNil.code shouldBe "nil" + + // Then block: { let a = 0; print(a, existing) } (no assignment for 'existing') + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + + val List(thenAssign, printCall) = thenBlock.astChildren.isCall.l + thenAssign.name shouldBe Operators.assignment + thenAssign.code shouldBe s"a = $tmpName" + printCall.code shouldBe "print(a, existing)" + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" + } + + "testGuardLetWithOtherConditions" in { + val cpg = code(""" + |func test(flag: Bool) { + | guard let a = foo(), flag else { + | return + | } + | print(a) + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal) = methodBlock.astChildren.isLocal.nameNot("self").l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + + val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { ((0 = foo()) != nil) && flag } + val List(condBlock) = guardIf.condition.isBlock.l + val List(andCheck) = condBlock.astChildren.isCall.l + andCheck.name shouldBe Operators.logicalAnd + + inside(andCheck.argument.l) { + case List(nilCheck: Call, flag: Identifier) => + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = foo()) != nil" + flag.name shouldBe "flag" + flag.code shouldBe "flag" + case List(flag: Identifier, nilCheck: Call) => + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = foo()) != nil" + flag.name shouldBe "flag" + flag.code shouldBe "flag" + } + + // Then block: { let a = 0; print(a) } + val List(thenBlock) = guardIf.whenTrue.isBlock.l + + val List(aLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + + val List(thenAssign, printCall) = thenBlock.astChildren.isCall.l + thenAssign.name shouldBe Operators.assignment + thenAssign.code shouldBe s"a = $tmpName" + printCall.name shouldBe "print" + printCall.code shouldBe "print(a)" + + val List(elseReturn) = guardIf.whenFalse.l + elseReturn.code shouldBe "return" } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala index f7d0ec63bbef..3073e58457c4 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/GuardTopLevelTests.scala @@ -5,6 +5,7 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite import io.shiftleft.codepropertygraph.generated.ControlStructureTypes import io.shiftleft.codepropertygraph.generated.Operators +import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class GuardTopLevelTests extends SwiftSrc2CpgSuite { @@ -14,16 +15,55 @@ class GuardTopLevelTests extends SwiftSrc2CpgSuite { val cpg = code(""" |let a: Int? = 1 |guard let b = a else {} + |print(b) |""".stripMargin) val List(globalBlock) = cpg.method.nameExact("").block.l - val List(localA) = globalBlock.local.nameExact("a").l + + // After desugaring: 0 and `a` in global block + val List(tmpLocal, localA) = globalBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + localA.name shouldBe "a" localA.typeFullName shouldBe "Swift.Int" - val List(localB) = globalBlock.local.nameExact("b").l - localB.typeFullName shouldBe "ANY" - val assigns = cpg.call.nameExact(Operators.assignment).code.l - assigns shouldBe List("let a: Int? = 1", "let b = a") - val List(guardIf) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l - guardIf.code should startWith("guard let b = a else") + + // After desugaring, b is in the guard's then block, not the global block + val List(guardIf: ControlStructure) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + guardIf.code shouldBe "guard let b = a else {}" + guardIf.controlStructureType shouldBe ControlStructureTypes.IF + + // Check that desugaring created the temp variable and nil check in condition + val List(condBlock) = guardIf.condition.isBlock.l + + val List(nilCheck) = condBlock.astChildren.isCall.l + nilCheck.name shouldBe Operators.notEquals + nilCheck.code shouldBe s"($tmpName = a) != nil" + + val List(assignment) = nilCheck.argument.isCall.l + assignment.name shouldBe Operators.assignment + assignment.code shouldBe s"$tmpName = a" + + val List(tmpArg, aArg) = assignment.argument.l + tmpArg.code shouldBe tmpName + aArg.code shouldBe "a" + + val List(nilLit) = nilCheck.argument.isLiteral.l + nilLit.code shouldBe "nil" + + // Check that b local is in the then block along with code that follows the guard + val List(thenBlock) = guardIf.whenTrue.isBlock.l + val List(localB) = thenBlock.local.l + localB.name shouldBe "b" + + // Verify the print(b) call is also in the then block (code following guard) + val List(bAssign, printCall) = thenBlock.astChildren.isCall.l + bAssign.name shouldBe Operators.assignment + bAssign.code shouldBe s"b = $tmpName" + printCall.code shouldBe "print(b)" + + val List(bArg) = printCall.argument.l + bArg.code shouldBe "b" + + guardIf.whenFalse.astChildren shouldBe empty } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala index ce6d1a5ca332..8cab869386a9 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/StatementTests.scala @@ -1,24 +1,51 @@ package io.joern.swiftsrc2cpg.passes.ast import io.joern.swiftsrc2cpg.testfixtures.SwiftSrc2CpgSuite - import io.shiftleft.codepropertygraph.generated.* -import io.shiftleft.codepropertygraph.generated.nodes.* import io.shiftleft.semanticcpg.language.* class StatementTests extends SwiftSrc2CpgSuite { "StatementTests" should { - "testIf" in { + "testIfLet" in { val cpg = code(""" - |if let baz {} - |if let self = self {} + |func test() { + | if let baz = optionalValue { + | print(baz) + | } + |} |""".stripMargin) - val ifs = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).code.l - ifs shouldBe List("if let baz {}", "if let self = self {}") - cpg.local.name.l should contain allOf ("baz", "self") - cpg.call.codeExact("let self = self").l should not be empty + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal, optionalValueLocal) = methodBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + optionalValueLocal.name shouldBe "optionalValue" + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: desugared to { (0 = optionalValue) != nil } + val List(condBlock) = ifNode.condition.isBlock.l + + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"($tmpName = optionalValue) != nil" + + val List(condAssign) = condCheck.argument.assignment.l + condAssign.code shouldBe s"$tmpName = optionalValue" + condAssign.argument(1).code shouldBe tmpName + condAssign.argument(2).code shouldBe "optionalValue" + + // Then block: { let baz = 0; print(baz) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + val List(bazLocal) = thenBlock.astChildren.isLocal.l + bazLocal.name shouldBe "baz" + + val List(thenAssign) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + thenAssign.code shouldBe s"baz = $tmpName" + thenAssign.argument(1).code shouldBe "baz" + thenAssign.argument(2).code shouldBe tmpName } "testDoCatch" in { @@ -120,6 +147,334 @@ class StatementTests extends SwiftSrc2CpgSuite { ifNode.code should startWith("if true") } + "testIfLetWithoutInitializer" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | if let optionalValue { + | print(optionalValue) + | } + |} + |""".stripMargin) + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // For "if let optionalValue" without explicit initializer: + // Condition: optionalValue != nil (direct comparison, no block) + val List(condCheck) = ifNode.condition.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe "optionalValue != nil" + condCheck.argument(1).code shouldBe "optionalValue" + condCheck.argument(2).code shouldBe "nil" + + // Then branch: just the original body, no new local or assignment for optionalValue + // The body might be a block or another structure depending on CodeBlockSyntax handling + val thenNodes = ifNode.whenTrue.l + thenNodes should not be empty + + // Verify no new local named "optionalValue" was created in the then branch + val localsInThen = thenNodes.flatMap(_.ast.isLocal.nameExact("optionalValue").l) + localsInThen shouldBe empty + + // Verify no assignment to optionalValue in the then branch + val assignmentsInThen = thenNodes.flatMap(_.ast.isCall.nameExact(Operators.assignment).l) + val optionalValueAssignments = assignmentsInThen.filter(_.argument(1).code == "optionalValue") + optionalValueAssignments shouldBe empty + } + + "testWhileLet" in { + val cpg = code(""" + |func test() { + | while let item = iterator.next() { + | print(item) + | } + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmpLocal, iteratorLocal) = methodBlock.local.l + val tmpName = tmpLocal.name + tmpName shouldBe "0" + iteratorLocal.name shouldBe "iterator" + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // Condition: desugared to { (0 = iterator.next()) != nil } + val List(condBlock) = whileNode.condition.isBlock.l + + val List(condCheck) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + condCheck.code shouldBe s"($tmpName = iterator.next()) != nil" + + val List(condAssign) = condCheck.argument.assignment.l + condAssign.code shouldBe s"$tmpName = iterator.next()" + condAssign.argument(1).code shouldBe tmpName + condAssign.argument(2).code shouldBe "iterator.next()" + + // Loop body: { let item = 0; print(item) } + val List(bodyBlock) = whileNode.whenTrue.isBlock.l + + val List(itemLocal) = bodyBlock.astChildren.isLocal.l + itemLocal.name shouldBe "item" + + val List(bodyAssign) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).l + bodyAssign.code shouldBe s"item = $tmpName" + bodyAssign.argument(1).code shouldBe "item" + bodyAssign.argument(2).code shouldBe tmpName + } + + "testWhileLetWithoutInitializer" in { + val cpg = code(""" + |func test(optionalValue: Int?) { + | while let optionalValue { + | print(optionalValue) + | } + |} + |""".stripMargin) + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // For "while let optionalValue" without explicit initializer: + // Condition: optionalValue != nil (direct comparison, no block) + val List(condCheck) = whileNode.condition.isCall.l + condCheck.name shouldBe Operators.notEquals + condCheck.code shouldBe "optionalValue != nil" + condCheck.argument(1).code shouldBe "optionalValue" + condCheck.argument(2).code shouldBe "nil" + + // Loop body: just the original body, no new local or assignment for optionalValue + val bodyNodes = whileNode.whenTrue.l + bodyNodes should not be empty + + // Verify no new local named "optionalValue" was created in the loop body + val localsInBody = bodyNodes.flatMap(_.ast.isLocal.nameExact("optionalValue").l) + localsInBody shouldBe empty + + // Verify no assignment to optionalValue in the loop body + val assignmentsInBody = bodyNodes.flatMap(_.ast.isCall.nameExact(Operators.assignment).l) + val optionalValueAssignments = assignmentsInBody.filter(_.argument(1).code == "optionalValue") + optionalValueAssignments shouldBe empty + } + + "testIfLetMultiple" in { + val cpg = code(""" + |func test() { + | if let a = Optional(1), let b = Optional(2) { + | print(a, b) + | } + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1 in method block + val List(tmp0Local, tmp1Local) = methodBlock.local.nameNot("self").l + val tmp0Name = tmp0Local.name + val tmp1Name = tmp1Local.name + tmp0Name shouldBe "0" + tmp1Name shouldBe "1" + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { (0 = Optional(1)) != nil && (1 = Optional(2)) != nil } + val List(condBlock) = ifNode.condition.isBlock.l + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp0Name = Optional(1)) != nil" + check2.code shouldBe s"($tmp1Name = Optional(2)) != nil" + + val List(assign1, assign2) = andCheck.argument.isCall.argument.assignment.l + assign1.code shouldBe s"$tmp0Name = Optional(1)" + assign2.code shouldBe s"$tmp1Name = Optional(2)" + + // Then block: { a = 0; b = 1; print(a, b) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + val List(aLocal, bLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + + val List(unwrapA, unwrapB) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + unwrapA.code shouldBe s"a = $tmp0Name" + unwrapB.code shouldBe s"b = $tmp1Name" + } + + "testIfLetMixed" in { + val cpg = code(""" + |func test(opt2: Int?) { + | if let a = Optional(1), let opt2 { + | print(a, opt2) + | } + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + val List(tmp0Local) = methodBlock.local.nameNot("self").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { (0 = Optional(1)) != nil && opt2 != nil } + val List(condBlock) = ifNode.condition.isBlock.l + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp0Name = Optional(1)) != nil" + check2.code shouldBe "opt2 != nil" + + val List(assign1) = check1.argument.assignment.l + assign1.code shouldBe s"$tmp0Name = Optional(1)" + + // Then block: { a = 0; print(a, opt2) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + val List(aLocal) = thenBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + + val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).l + unwrapA.code shouldBe s"a = $tmp0Name" + } + + "testWhileLetMultiple" in { + val cpg = code(""" + |func test() { + | while let a = iterator1.next(), let b = iterator2.next() { + | print(a, b) + | } + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0, 1 in method block + val List(tmp0Local, tmp1Local) = methodBlock.local.nameNot("iterator1", "iterator2").l + val tmp0Name = tmp0Local.name + val tmp1Name = tmp1Local.name + tmp0Name shouldBe "0" + tmp1Name shouldBe "1" + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // Condition: { (0 = iterator1.next()) != nil && (1 = iterator2.next()) != nil } + val List(condBlock) = whileNode.condition.isBlock.l + + val List(andCheck) = condBlock.astChildren.isCall.nameExact(Operators.logicalAnd).l + val List(check1, check2) = andCheck.argument.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp0Name = iterator1.next()) != nil" + check2.code shouldBe s"($tmp1Name = iterator2.next()) != nil" + + val List(assign1, assign2) = andCheck.argument.isCall.argument.assignment.l + assign1.code shouldBe s"$tmp0Name = iterator1.next()" + assign2.code shouldBe s"$tmp1Name = iterator2.next()" + + // Loop body: { a = 0; b = 1; print(a, b) } + val List(bodyBlock) = whileNode.whenTrue.isBlock.l + + val List(aLocal, bLocal) = bodyBlock.astChildren.isLocal.l + aLocal.name shouldBe "a" + bLocal.name shouldBe "b" + + val List(unwrapA, unwrapB) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).l + unwrapA.code shouldBe s"a = $tmp0Name" + unwrapB.code shouldBe s"b = $tmp1Name" + } + + "testIfLetMixedWithTuplePattern" in { + val cpg = code(""" + |func test() { + | if let a = foo(), let (b, c) = bar() { + | print(a, b, c) + | } + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + // FIXME: because optional tuple bindings are not handled yet the default handling + // (which is wrong) will not create correct locals and the arguments to print end up + // creating locals in the method block. + val List(tmp0Local) = methodBlock.local.nameNot("self", "b", "c").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + + val List(ifNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.IF).l + + // Condition: { (0 = foo()) != nil } (tuple pattern excluded from condition) + val List(condBlock) = ifNode.condition.isBlock.l + + val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp0Name = foo()) != nil" + + val List(assign1) = check1.argument.assignment.l + assign1.code shouldBe s"$tmp0Name = foo()" + assign1.argument(1).code shouldBe tmp0Name + assign1.argument(2).code shouldBe "foo()" + + // Then block: { a = 0; let (b, c) = bar(); print(a, b, c) } + val List(thenBlock) = ifNode.whenTrue.isBlock.l + + // First child: unwrapping assignment for 'a' + val List(aLocal) = thenBlock.astChildren.isLocal.nameExact("a").l + aLocal.name shouldBe "a" + + val List(unwrapA) = thenBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp0Name").l + unwrapA.argument(1).code shouldBe "a" + unwrapA.argument(2).code shouldBe tmp0Name + + // Second child: tuple binding block for (b, c) = bar() + // The tuple binding creates a nested block with tmp assignment + tuple destructuring + val List(tupleBindingBlock) = thenBlock.astChildren.isBlock.l + + // Inside the tuple binding block, there should be an assignment involving bar() + val List(barAssignment) = tupleBindingBlock.astChildren.isCall.nameExact(Operators.assignment).code(".*bar.*").l + barAssignment.code should include("bar()") + } + + "testWhileLetMixedWithTuplePattern" in { + val cpg = code(""" + |func test() { + | while let a = foo(), let (b, c) = bar() { + | print(a, b, c) + | } + |} + |""".stripMargin) + val List(methodBlock) = cpg.method.nameExact("test").block.l + // After desugaring: 0 in method block + // FIXME: because optional tuple bindings are not handled yet the default handling + // (which is wrong) will not create correct locals and the arguments to print end up + // creating locals in the method block. + val List(tmp0Local) = methodBlock.local.nameNot("self", "b", "c").l + val tmp0Name = tmp0Local.name + tmp0Name shouldBe "0" + + val List(whileNode) = cpg.controlStructure.controlStructureType(ControlStructureTypes.WHILE).l + + // Condition: { (0 = foo()) != nil } (tuple pattern excluded from condition) + val List(condBlock) = whileNode.condition.isBlock.l + + val List(check1) = condBlock.astChildren.isCall.nameExact(Operators.notEquals).l + check1.code shouldBe s"($tmp0Name = foo()) != nil" + + val List(assign1) = check1.argument.assignment.l + assign1.code shouldBe s"$tmp0Name = foo()" + assign1.argument(1).code shouldBe tmp0Name + assign1.argument(2).code shouldBe "foo()" + + // Loop body: { a = 0; let (b, c) = bar(); print(a, b, c) } + val List(bodyBlock) = whileNode.whenTrue.isBlock.l + + // First child: unwrapping assignment for 'a' + val List(aLocal) = bodyBlock.astChildren.isLocal.nameExact("a").l + aLocal.name shouldBe "a" + + val List(unwrapA) = bodyBlock.astChildren.isCall.nameExact(Operators.assignment).codeExact(s"a = $tmp0Name").l + unwrapA.argument(1).code shouldBe "a" + unwrapA.argument(2).code shouldBe tmp0Name + + // Second child: tuple binding block for (b, c) = bar() + // The tuple binding creates a nested block with tmp assignment + tuple destructuring + val List(tupleBindingBlock) = bodyBlock.astChildren.isBlock.l + + // Inside the tuple binding block, there should be an assignment involving bar() + val List(barAssignment) = tupleBindingBlock.astChildren.isCall.nameExact(Operators.assignment).code(".*bar.*").l + barAssignment.code should include("bar()") + } + } } diff --git a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala index daeed30abc40..af636bc19b1b 100644 --- a/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala +++ b/joern-cli/frontends/swiftsrc2cpg/src/test/scala/io/joern/swiftsrc2cpg/passes/ast/WhileTests.scala @@ -68,7 +68,8 @@ class WhileTests extends SwiftSrc2CpgSuite { inside(controlStruct.condition.l) { case List(cndNode: Literal) => cndNode.code shouldBe "true" } - controlStruct.whenTrue.code.l shouldBe List("""print("Endless Loop")""") + controlStruct.condition.code.l shouldBe List("true") + controlStruct.doBodyOut.code.l shouldBe List("""print("Endless Loop")""") controlStruct.lineNumber shouldBe Some(2) controlStruct.columnNumber shouldBe Some(1) } @@ -94,7 +95,8 @@ class WhileTests extends SwiftSrc2CpgSuite { n.name shouldBe "n" n.order shouldBe 2 } - controlStruct.whenTrue.astChildren.code.l shouldBe List("print(i)", "i = i + 1") + controlStruct.condition.code.l shouldBe List("i <= n") + controlStruct.doBodyOut.astChildren.code.l shouldBe List("print(i)", "i = i + 1") controlStruct.lineNumber shouldBe Some(2) controlStruct.columnNumber shouldBe Some(1) }