From 5c159832ce47fad7765d5e1643769ead475315b1 Mon Sep 17 00:00:00 2001 From: allsmog Date: Mon, 30 Mar 2026 08:57:54 -0700 Subject: [PATCH] [gosrc2cpg] Add defer/go/select/send statements, fix fallthrough and tuple returns - Add DeferStmt, GoStmt, SelectStmt, SendStmt, CommClause parser node types - Add handler methods for all 5 new statement types - Fix fallthrough to produce proper control structure node - Fix tuple return types to correctly represent (type1, type2) - Remove stale TODO comments --- .../astcreation/AstForFunctionsCreator.scala | 10 +- .../astcreation/AstForLambdaCreator.scala | 10 +- .../astcreation/AstForStatementsCreator.scala | 51 ++++++++- .../astcreation/CommonCacheBuilder.scala | 10 +- .../io/joern/gosrc2cpg/parser/ParserAst.scala | 8 ++ .../io/joern/gosrc2cpg/utils/Constants.scala | 1 + .../go2cpg/passes/ast/ConcurrencyTests.scala | 103 ++++++++++++++++++ .../joern/go2cpg/passes/ast/MethodTests.scala | 30 +++-- .../joern/go2cpg/passes/ast/SwitchTests.scala | 3 +- 9 files changed, 202 insertions(+), 24 deletions(-) create mode 100644 joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConcurrencyTests.scala diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForFunctionsCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForFunctionsCreator.scala index f976cb403b34..2036ac9dbe3d 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForFunctionsCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForFunctionsCreator.scala @@ -15,10 +15,12 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th def astForFuncDecl(funcDecl: ParserNodeInfo): Seq[Ast] = { val methodMetadata = processFuncDecl(funcDecl.json) - // TODO: handle multiple return type or tuple (int, int) - val (returnTypeStr, returnTypeInfo) = - getReturnType(funcDecl.json(ParserKeys.Type), methodMetadata.genericTypeMethodMap).headOption - .getOrElse((Defines.voidTypeName, funcDecl)) + val returnTypes = getReturnType(funcDecl.json(ParserKeys.Type), methodMetadata.genericTypeMethodMap) + val (returnTypeStr, returnTypeInfo) = returnTypes match { + case Seq() => (Defines.voidTypeName, funcDecl) + case Seq(single) => single + case multiple => (s"(${multiple.map(_._1).mkString(", ")})", multiple.head._2) + } val methodReturn = methodReturnNode(returnTypeInfo, returnTypeStr) val methodNode_ = methodNode( funcDecl, diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForLambdaCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForLambdaCreator.scala index 5062b3a8407f..ef06d6b23fb2 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForLambdaCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForLambdaCreator.scala @@ -57,10 +57,12 @@ trait AstForLambdaCreator(implicit withSchemaValidation: ValidationMode) { this: protected def generateLambdaSignature(funcType: ParserNodeInfo): LambdaFunctionMetaData = { val genericTypeMethodMap: Map[String, List[String]] = Map() - // TODO: While handling the tuple return type we need to handle it here as well. - val (returnTypeStr, returnTypeInfo) = - getReturnType(funcType.json, genericTypeMethodMap).headOption - .getOrElse((Defines.voidTypeName, funcType)) + val returnTypes = getReturnType(funcType.json, genericTypeMethodMap) + val (returnTypeStr, returnTypeInfo) = returnTypes match { + case Seq() => (Defines.voidTypeName, funcType) + case Seq(single) => single + case multiple => (s"(${multiple.map(_._1).mkString(", ")})", multiple.head._2) + } val methodReturn = methodReturnNode(returnTypeInfo, returnTypeStr) val params = funcType.json(ParserKeys.Params)(ParserKeys.List) diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala index 200979440f54..8d501e77f547 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/AstForStatementsCreator.scala @@ -35,12 +35,17 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t case BranchStmt => Seq(astForBranchStatement(statement)) case BlockStmt => Seq(astForBlockStatement(statement, argIndex)) case CaseClause => astForCaseClause(statement) + case CommClause => astForCommClause(statement) case DeclStmt => astForNode(statement.json(ParserKeys.Decl)) + case DeferStmt => Seq(astForDeferStatement(statement)) case ExprStmt => astsForExpression(createParserNodeInfo(statement.json(ParserKeys.X))) case ForStmt => Seq(astForForStatement(statement)) + case GoStmt => Seq(astForGoStatement(statement)) case IfStmt => astForIfStatement(statement) case IncDecStmt => Seq(astForIncDecStatement(statement)) case RangeStmt => Seq(astForRangeStatement(statement)) + case SelectStmt => Seq(astForSelectStatement(statement)) + case SendStmt => Seq(astForSendStatement(statement)) case SwitchStmt => Seq(astForSwitchStatement(statement)) case TypeSwitchStmt => Seq(astForTypeSwitchStatement(statement)) case ReturnStmt => Seq(astForReturnStatement(statement)) @@ -51,7 +56,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t } private def astForReturnStatement(returnStmt: ParserNodeInfo): Ast = { - // TODO: Need to handle the tuple return node handling val cpgReturn = returnNode(returnStmt, returnStmt.code) val expast = returnStmt .json(ParserKeys.Results) @@ -273,8 +277,49 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t // To update the cache of parserNode with the labelled statement Try(createParserNodeInfo(branchStmt.json(ParserKeys.Label)(ParserKeys.Obj)(ParserKeys.Decl))) Ast(controlStructureNode(branchStmt, ControlStructureTypes.GOTO, branchStmt.code)) - case "fallthrough" => // TODO handling for FALLTHROUGH - Ast() + case "fallthrough" => + Ast(controlStructureNode(branchStmt, "fallthrough", branchStmt.code)) } } + + private def astForDeferStatement(deferStmt: ParserNodeInfo): Ast = { + val callAsts = astForNode(deferStmt.json(ParserKeys.Call)) + callAsts.headOption.getOrElse(Ast()) + } + + private def astForGoStatement(goStmt: ParserNodeInfo): Ast = { + val callAsts = astForNode(goStmt.json(ParserKeys.Call)) + callAsts.headOption.getOrElse(Ast()) + } + + private def astForSendStatement(sendStmt: ParserNodeInfo): Ast = { + val channelAst = astForNode(sendStmt.json(ParserKeys.Chan)) + val valueAst = astForNode(sendStmt.json(ParserKeys.Value)) + val arguments = channelAst ++ valueAst + val cNode = + callNode(sendStmt, sendStmt.code, Operator.send, Operator.send, DispatchTypes.STATIC_DISPATCH) + callAst(cNode, arguments) + } + + private def astForSelectStatement(selectStmt: ParserNodeInfo): Ast = { + val selectNode = controlStructureNode(selectStmt, ControlStructureTypes.SWITCH, s"select") + val stmtAsts = astsForStatement(createParserNodeInfo(selectStmt.json(ParserKeys.Body))) + controlStructureAst(selectNode, None, stmtAsts) + } + + private def astForCommClause(commClause: ParserNodeInfo): Seq[Ast] = { + val commAst = Try(commClause.json(ParserKeys.Comm)).toOption match { + case Some(commJson) => + val commParserNode = createParserNodeInfo(commJson) + val jumpTarget = jumpTargetNode(commClause, "case", s"case ${commParserNode.code}") + val commStmtAsts = astsForStatement(commParserNode).toList + Ast(jumpTarget) :: commStmtAsts + case _ => + val target = jumpTargetNode(commClause, "default", "default") + Seq(Ast(target)) + } + + val bodyAst = commClause.json(ParserKeys.Body).arr.map(createParserNodeInfo).flatMap(astsForStatement(_)).toList + commAst ++: bodyAst + } } diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala index aa2abec8212a..863f9f5484db 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/astcreation/CommonCacheBuilder.scala @@ -83,11 +83,13 @@ trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this: case _ => (s"$fullyQualifiedPackage.$name", fullyQualifiedPackage) } - // TODO: handle multiple return type or tuple (int, int) val genericTypeMethodMap = processTypeParams(funcDeclVal(ParserKeys.Type)) - val (returnTypeStr, _) = - getReturnType(funcDeclVal(ParserKeys.Type), genericTypeMethodMap).headOption - .getOrElse((Defines.voidTypeName, null)) + val returnTypes = getReturnType(funcDeclVal(ParserKeys.Type), genericTypeMethodMap) + val returnTypeStr = returnTypes match { + case Seq() => Defines.voidTypeName + case Seq(one) => one._1 + case multiple => s"(${multiple.map(_._1).mkString(", ")})" + } val params = funcDeclVal(ParserKeys.Type)(ParserKeys.Params)(ParserKeys.List) val signature = s"$methodFullname(${parameterSignature(params, genericTypeMethodMap)})${ diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala index 6aaa955028ac..31307082f703 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/parser/ParserAst.scala @@ -47,6 +47,10 @@ object ParserAst { object ForStmt extends BaseStmt object RangeStmt extends BaseStmt object BranchStmt extends BaseStmt + object DeferStmt extends BaseStmt + object GoStmt extends BaseStmt + object SelectStmt extends BaseStmt + object SendStmt extends BaseStmt object LabeledStmt extends BaseStmt sealed trait BasePrimitive extends ParserNode object BasicLit extends BasePrimitive @@ -59,6 +63,7 @@ object ParserAst { object FuncDecl extends ParserNode object ValueSpec extends ParserNode object CaseClause extends ParserNode + object CommClause extends ParserNode object InterfaceType extends ParserNode object FuncType extends ParserNode object Ellipsis extends ParserNode @@ -74,6 +79,9 @@ object ParserKeys { val Assign = "Assign" val Body = "Body" + val Call = "Call" + val Chan = "Chan" + val Comm = "Comm" val Cond = "Cond" val Decl = "Decl" val Decls = "Decls" diff --git a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/Constants.scala b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/Constants.scala index 81ccbffbea5a..3e78d5cbcd17 100644 --- a/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/Constants.scala +++ b/joern-cli/frontends/gosrc2cpg/src/main/scala/io/joern/gosrc2cpg/utils/Constants.scala @@ -8,4 +8,5 @@ object UtilityConstants { } object Operator { val unknown = ".unknown" + val send = ".send" } diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConcurrencyTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConcurrencyTests.scala new file mode 100644 index 000000000000..c6fa686a7e53 --- /dev/null +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/ConcurrencyTests.scala @@ -0,0 +1,103 @@ +package io.joern.go2cpg.passes.ast + +import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite +import io.shiftleft.codepropertygraph.generated.ControlStructureTypes +import io.shiftleft.codepropertygraph.generated.nodes.* +import io.shiftleft.semanticcpg.language.* + +class ConcurrencyTests extends GoCodeToCpgSuite { + + "AST creation for defer statements" should { + "create call node for deferred function call" in { + val cpg = code(""" + |package main + |func main() { + | f := openFile() + | defer f.Close() + |}""".stripMargin) + + val closeCall = cpg.call.name("Close").head + closeCall.code shouldBe "f.Close()" + closeCall.lineNumber shouldBe Some(5) + } + + "create call node for deferred function with arguments" in { + val cpg = code(""" + |package main + |import "fmt" + |func main() { + | defer fmt.Println("done") + |}""".stripMargin) + + val printlnCall = cpg.call.name("Println").head + printlnCall.code shouldBe "fmt.Println(\"done\")" + } + } + + "AST creation for go statements" should { + "create call node for goroutine function call" in { + val cpg = code(""" + |package main + |func handler(x int) {} + |func main() { + | go handler(42) + |}""".stripMargin) + + val handlerCalls = cpg.call.name("handler").l + handlerCalls.size should be >= 1 + handlerCalls.exists(_.code == "handler(42)") shouldBe true + } + + "create call node for goroutine with simple function" in { + val cpg = code(""" + |package main + |func worker(id int) {} + |func main() { + | go worker(1) + | go worker(2) + |}""".stripMargin) + + val workerCalls = cpg.call.name("worker").l + workerCalls.size shouldBe 2 + } + } + + "AST creation for send statements" should { + "create operator call for channel send" in { + val cpg = code(""" + |package main + |func main() { + | ch := make(chan int) + | ch <- 42 + |}""".stripMargin) + + val sendCall = cpg.call.name(".send").head + sendCall.code shouldBe "ch <- 42" + sendCall.argument.size shouldBe 2 + } + } + + "AST creation for select statements" should { + "create control structure for select with cases" in { + val cpg = code(""" + |package main + |func main() { + | ch1 := make(chan int) + | ch2 := make(chan int) + | select { + | case msg := <-ch1: + | x := msg + | case ch2 <- 1: + | y := 2 + | default: + | z := 3 + | } + |}""".stripMargin) + + val selectStmt = cpg.method.name("main").controlStructure.l + .filter(_.code == "select") + selectStmt.size shouldBe 1 + selectStmt.head.controlStructureType shouldBe ControlStructureTypes.SWITCH + } + } +} diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodTests.scala index d4fd3d89c349..ab76a37e2b8b 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/MethodTests.scala @@ -135,16 +135,15 @@ class MethodTests extends GoCodeToCpgSuite { |} |""".stripMargin) - "Be correct with method node properties" ignore { + "Be correct with method node properties" in { val List(x) = cpg.method.name("foo").l x.name shouldBe "foo" x.fullName shouldBe "main.foo" x.code should startWith("func foo() (fpkg.Sample,error){") - // TODO: Tuple handling needs to be done properly to return both the types. - x.signature shouldBe "main.foo()(joern.io/sample/fpkg.Sample,error)" + x.signature shouldBe "main.foo()(joern.io/sample/fpkg.Sample, error)" x.isExternal shouldBe false - x.order shouldBe 2 + x.order shouldBe 1 x.filename shouldBe "Test0.go" x.lineNumber shouldBe Option(4) x.lineNumberEnd shouldBe Option(6) @@ -153,8 +152,7 @@ class MethodTests extends GoCodeToCpgSuite { "Be correct with return node" in { cpg.method.name("foo").methodReturn.size shouldBe 1 val List(x) = cpg.method.name("foo").methodReturn.l - // TODO: Tuple handling needs to be done properly to return both the types. - x.typeFullName shouldBe "joern.io/sample/fpkg.Sample" + x.typeFullName shouldBe "(joern.io/sample/fpkg.Sample, error)" } } @@ -1607,5 +1605,23 @@ class MethodTests extends GoCodeToCpgSuite { // sem := make(chan int, concurrency) // As well as example of "map" // TODO: Add unit tests for lambda expression as a parameter - // TODO: Add unit test for tuple return + + "Function with tuple return type" should { + val cpg = code(""" + |package main + |func divide(a, b float64) (float64, error) { + | return a / b, nil + |} + |""".stripMargin) + + "Be correct with method return type" in { + val List(x) = cpg.method.name("divide").methodReturn.l + x.typeFullName shouldBe "(float64, error)" + } + + "Be correct with method signature" in { + val List(x) = cpg.method.name("divide").l + x.signature shouldBe "main.divide(float64, float64)(float64, error)" + } + } } diff --git a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/SwitchTests.scala b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/SwitchTests.scala index 60d314c641a7..7cf8309b256a 100644 --- a/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/SwitchTests.scala +++ b/joern-cli/frontends/gosrc2cpg/src/test/scala/io/joern/go2cpg/passes/ast/SwitchTests.scala @@ -148,8 +148,7 @@ class SwitchTests extends GoCodeToCpgSuite { ) } - // TODO Need to handle `fallthrough` statements - "ast creation for fallthrough" ignore { + "ast creation for fallthrough" should { "be correct" in { val cpg = code("""package main