diff --git a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala index 3ca98e660b7c..c2f8c3e5430f 100644 --- a/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala +++ b/joern-cli/frontends/kotlin2cpg/src/main/scala/io/joern/kotlin2cpg/ast/AstForStatementsCreator.scala @@ -528,15 +528,17 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { else astForTryAsExpression(expr, argIdx, argNameMaybe, annotations) } - def astForBreak(expr: KtBreakExpression): Ast = { - val node = controlStructureNode(expr, ControlStructureTypes.BREAK, code(expr)) - Ast(node) - } + def astForBreak(expr: KtBreakExpression): Ast = + Option(expr.getLabelName) match { + case Some(labelName) => breakAst(expr, code(expr), labelName) + case None => breakAst(expr, code(expr), 1) + } - def astForContinue(expr: KtContinueExpression): Ast = { - val node = controlStructureNode(expr, ControlStructureTypes.CONTINUE, code(expr)) - Ast(node) - } + def astForContinue(expr: KtContinueExpression): Ast = + Option(expr.getLabelName) match { + case Some(labelName) => continueAst(expr, code(expr), labelName) + case None => continueAst(expr, code(expr), 1) + } def astForThrowExpression( expr: KtThrowExpression, diff --git a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala index 0b244d330225..ed8a996b0cd0 100644 --- a/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala +++ b/joern-cli/frontends/kotlin2cpg/src/test/scala/io/joern/kotlin2cpg/querying/ControlStructureTests.scala @@ -1,7 +1,7 @@ package io.joern.kotlin2cpg.querying import io.joern.kotlin2cpg.testfixtures.KotlinCode2CpgFixture -import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, ControlStructure, Identifier, Local} +import io.shiftleft.codepropertygraph.generated.nodes.{Block, Call, ControlStructure, Identifier, JumpLabel, Local} import io.shiftleft.codepropertygraph.generated.{ControlStructureTypes, DispatchTypes, EdgeTypes, Operators} import io.shiftleft.semanticcpg.language.* import io.shiftleft.codepropertygraph.generated.nodes.Literal @@ -717,4 +717,79 @@ class ControlStructureTests extends KotlinCode2CpgFixture(withOssDataflow = fals } // TODO: also add test for the loop range, when it is with downTo or whatever + + "CPG for code with labeled break statement" should { + val cpg = code(""" + |package mypkg + |fun foo() { + | outer@ for (i in 0..10) { + | for (j in 0..10) { + | if (j == 5) break@outer + | } + | } + |}""".stripMargin) + + "should contain a BREAK with a JUMP_LABEL child" in { + val List(breakNode) = cpg.controlStructure.isBreak.l + breakNode.code shouldBe "break@outer" + val List(jumpLabel) = breakNode.astChildren.collectAll[JumpLabel].l + jumpLabel.name shouldBe "outer" + jumpLabel.order shouldBe 1 + } + + "should have a JUMP_ARGUMENT edge from break to the JUMP_LABEL" in { + val List(breakNode) = cpg.controlStructure.isBreak.l + val List(jumpLabel) = breakNode.jumpArgumentOut.collectAll[JumpLabel].l + jumpLabel.name shouldBe "outer" + } + } + + "CPG for code with labeled continue statement" should { + val cpg = code(""" + |package mypkg + |fun foo() { + | outer@ for (i in 0..10) { + | for (j in 0..10) { + | if (j == 3) continue@outer + | } + | } + |}""".stripMargin) + + "should contain a CONTINUE with a JUMP_LABEL child" in { + val List(continueNode) = cpg.controlStructure.isContinue.l + continueNode.code shouldBe "continue@outer" + val List(jumpLabel) = continueNode.astChildren.collectAll[JumpLabel].l + jumpLabel.name shouldBe "outer" + jumpLabel.order shouldBe 1 + } + + "should have a JUMP_ARGUMENT edge from continue to the JUMP_LABEL" in { + val List(continueNode) = cpg.controlStructure.isContinue.l + val List(jumpLabel) = continueNode.jumpArgumentOut.collectAll[JumpLabel].l + jumpLabel.name shouldBe "outer" + } + } + + "CPG for code with unlabeled break/continue" should { + val cpg = code(""" + |package mypkg + |fun foo() { + | for (i in 0..10) { + | if (i == 5) break + | if (i == 3) continue + | } + |}""".stripMargin) + + "should have no JUMP_LABEL child and no JUMP_ARGUMENT edge on break" in { + val List(breakNode) = cpg.controlStructure.isBreak.l + breakNode.astChildren.collectAll[JumpLabel].size shouldBe 0 + breakNode.jumpArgumentOut.size shouldBe 0 + } + + "should have no JUMP_LABEL child and no JUMP_ARGUMENT edge on continue" in { + val List(continueNode) = cpg.controlStructure.isContinue.l + continueNode.astChildren.collectAll[JumpLabel].size shouldBe 0 + continueNode.jumpArgumentOut.size shouldBe 0 + } + } } diff --git a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala index a0db6cdfd2b1..4c37c4327375 100644 --- a/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala +++ b/joern-cli/frontends/x2cpg/src/main/scala/io/joern/x2cpg/AstCreatorBase.scala @@ -245,6 +245,49 @@ abstract class AstCreatorBase[Node, NodeProcessor](filename: String)(implicit wi } } + /** Creates an AST for a labeled break statement. A `JumpLabel` child is created at order 1 and connected to the break + * node via a `JUMP_ARGUMENT` edge. + */ + def breakAst(node: Node, codeStr: String, labelName: String): Ast = + labeledJumpAst(node, ControlStructureTypes.BREAK, codeStr, labelName) + + /** Creates an AST for a level-based break statement. A `JumpLabel` child holding the level number is created at order + * 1 and connected to the break node via a `JUMP_ARGUMENT` edge. + */ + def breakAst(node: Node, codeStr: String, levels: Int): Ast = + labeledJumpAst(node, ControlStructureTypes.BREAK, codeStr, levels.toString) + + /** Creates an AST for a labeled continue statement. A `JumpLabel` child is created at order 1 and connected to the + * continue node via a `JUMP_ARGUMENT` edge. + */ + def continueAst(node: Node, codeStr: String, labelName: String): Ast = + labeledJumpAst(node, ControlStructureTypes.CONTINUE, codeStr, labelName) + + /** Creates an AST for a level-based continue statement. A `JumpLabel` child holding the level number is created at + * order 1 and connected to the continue node via a `JUMP_ARGUMENT` edge. + */ + def continueAst(node: Node, codeStr: String, levels: Int): Ast = + labeledJumpAst(node, ControlStructureTypes.CONTINUE, codeStr, levels.toString) + + private def labeledJumpAst(node: Node, jumpType: String, codeStr: String, labelName: String): Ast = { + val jumpNode = NewControlStructure() + .parserTypeName(node.getClass.getSimpleName) + .controlStructureType(jumpType) + .code(codeStr) + .lineNumber(line(node)) + .columnNumber(column(node)) + val jumpLabelNode = NewJumpLabel() + .parserTypeName(node.getClass.getSimpleName) + .name(labelName) + .code(labelName) + .lineNumber(line(node)) + .columnNumber(column(node)) + .order(1) + Ast(jumpNode) + .withChild(Ast(jumpLabelNode)) + .withJumpArgumentEdge(jumpNode, jumpLabelNode) + } + /** For the given try body, catch ASTs and finally AST, create a try-catch-finally AST with orders set correctly for * the ossdataflow engine. */