Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ trait AstForFunctionsCreator(implicit withSchemaValidation: ValidationMode) { th

def astForFuncDecl(funcDecl: ParserNodeInfo): Seq[Ast] = {
val methodMetadata = processFuncDecl(funcDecl.json)
// TODO: handle multiple return type or tuple (int, int)
val (returnTypeStr, returnTypeInfo) =
getReturnType(funcDecl.json(ParserKeys.Type), methodMetadata.genericTypeMethodMap).headOption
.getOrElse((Defines.voidTypeName, funcDecl))
val returnTypes = getReturnType(funcDecl.json(ParserKeys.Type), methodMetadata.genericTypeMethodMap)
val (returnTypeStr, returnTypeInfo) = returnTypes match {
case Seq() => (Defines.voidTypeName, funcDecl)
case Seq(single) => single
case multiple => (s"(${multiple.map(_._1).mkString(", ")})", multiple.head._2)
}
val methodReturn = methodReturnNode(returnTypeInfo, returnTypeStr)
val methodNode_ = methodNode(
funcDecl,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,12 @@ trait AstForLambdaCreator(implicit withSchemaValidation: ValidationMode) { this:

protected def generateLambdaSignature(funcType: ParserNodeInfo): LambdaFunctionMetaData = {
val genericTypeMethodMap: Map[String, List[String]] = Map()
// TODO: While handling the tuple return type we need to handle it here as well.
val (returnTypeStr, returnTypeInfo) =
getReturnType(funcType.json, genericTypeMethodMap).headOption
.getOrElse((Defines.voidTypeName, funcType))
val returnTypes = getReturnType(funcType.json, genericTypeMethodMap)
val (returnTypeStr, returnTypeInfo) = returnTypes match {
case Seq() => (Defines.voidTypeName, funcType)
case Seq(single) => single
case multiple => (s"(${multiple.map(_._1).mkString(", ")})", multiple.head._2)
}
val methodReturn = methodReturnNode(returnTypeInfo, returnTypeStr)

val params = funcType.json(ParserKeys.Params)(ParserKeys.List)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,17 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
case BranchStmt => Seq(astForBranchStatement(statement))
case BlockStmt => Seq(astForBlockStatement(statement, argIndex))
case CaseClause => astForCaseClause(statement)
case CommClause => astForCommClause(statement)
case DeclStmt => astForNode(statement.json(ParserKeys.Decl))
case DeferStmt => Seq(astForDeferStatement(statement))
case ExprStmt => astsForExpression(createParserNodeInfo(statement.json(ParserKeys.X)))
case ForStmt => Seq(astForForStatement(statement))
case GoStmt => Seq(astForGoStatement(statement))
case IfStmt => astForIfStatement(statement)
case IncDecStmt => Seq(astForIncDecStatement(statement))
case RangeStmt => Seq(astForRangeStatement(statement))
case SelectStmt => Seq(astForSelectStatement(statement))
case SendStmt => Seq(astForSendStatement(statement))
case SwitchStmt => Seq(astForSwitchStatement(statement))
case TypeSwitchStmt => Seq(astForTypeSwitchStatement(statement))
case ReturnStmt => Seq(astForReturnStatement(statement))
Expand All @@ -51,7 +56,6 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
}

private def astForReturnStatement(returnStmt: ParserNodeInfo): Ast = {
// TODO: Need to handle the tuple return node handling
val cpgReturn = returnNode(returnStmt, returnStmt.code)
val expast = returnStmt
.json(ParserKeys.Results)
Expand Down Expand Up @@ -273,8 +277,49 @@ trait AstForStatementsCreator(implicit withSchemaValidation: ValidationMode) { t
// To update the cache of parserNode with the labelled statement
Try(createParserNodeInfo(branchStmt.json(ParserKeys.Label)(ParserKeys.Obj)(ParserKeys.Decl)))
Ast(controlStructureNode(branchStmt, ControlStructureTypes.GOTO, branchStmt.code))
case "fallthrough" => // TODO handling for FALLTHROUGH
Ast()
case "fallthrough" =>
Ast(controlStructureNode(branchStmt, "fallthrough", branchStmt.code))
}
}

private def astForDeferStatement(deferStmt: ParserNodeInfo): Ast = {
val callAsts = astForNode(deferStmt.json(ParserKeys.Call))
callAsts.headOption.getOrElse(Ast())
}

private def astForGoStatement(goStmt: ParserNodeInfo): Ast = {
val callAsts = astForNode(goStmt.json(ParserKeys.Call))
callAsts.headOption.getOrElse(Ast())
}

private def astForSendStatement(sendStmt: ParserNodeInfo): Ast = {
val channelAst = astForNode(sendStmt.json(ParserKeys.Chan))
val valueAst = astForNode(sendStmt.json(ParserKeys.Value))
val arguments = channelAst ++ valueAst
val cNode =
callNode(sendStmt, sendStmt.code, Operator.send, Operator.send, DispatchTypes.STATIC_DISPATCH)
callAst(cNode, arguments)
}

private def astForSelectStatement(selectStmt: ParserNodeInfo): Ast = {
val selectNode = controlStructureNode(selectStmt, ControlStructureTypes.SWITCH, s"select")
val stmtAsts = astsForStatement(createParserNodeInfo(selectStmt.json(ParserKeys.Body)))
controlStructureAst(selectNode, None, stmtAsts)
}

private def astForCommClause(commClause: ParserNodeInfo): Seq[Ast] = {
val commAst = Try(commClause.json(ParserKeys.Comm)).toOption match {
case Some(commJson) =>
val commParserNode = createParserNodeInfo(commJson)
val jumpTarget = jumpTargetNode(commClause, "case", s"case ${commParserNode.code}")
val commStmtAsts = astsForStatement(commParserNode).toList
Ast(jumpTarget) :: commStmtAsts
case _ =>
val target = jumpTargetNode(commClause, "default", "default")
Seq(Ast(target))
}

val bodyAst = commClause.json(ParserKeys.Body).arr.map(createParserNodeInfo).flatMap(astsForStatement(_)).toList
commAst ++: bodyAst
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ trait CommonCacheBuilder(implicit withSchemaValidation: ValidationMode) { this:
case _ =>
(s"$fullyQualifiedPackage.$name", fullyQualifiedPackage)
}
// TODO: handle multiple return type or tuple (int, int)
val genericTypeMethodMap = processTypeParams(funcDeclVal(ParserKeys.Type))
val (returnTypeStr, _) =
getReturnType(funcDeclVal(ParserKeys.Type), genericTypeMethodMap).headOption
.getOrElse((Defines.voidTypeName, null))
val returnTypes = getReturnType(funcDeclVal(ParserKeys.Type), genericTypeMethodMap)
val returnTypeStr = returnTypes match {
case Seq() => Defines.voidTypeName
case Seq(one) => one._1
case multiple => s"(${multiple.map(_._1).mkString(", ")})"
}
val params = funcDeclVal(ParserKeys.Type)(ParserKeys.Params)(ParserKeys.List)
val signature =
s"$methodFullname(${parameterSignature(params, genericTypeMethodMap)})${
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ object ParserAst {
object ForStmt extends BaseStmt
object RangeStmt extends BaseStmt
object BranchStmt extends BaseStmt
object DeferStmt extends BaseStmt
object GoStmt extends BaseStmt
object SelectStmt extends BaseStmt
object SendStmt extends BaseStmt
object LabeledStmt extends BaseStmt
sealed trait BasePrimitive extends ParserNode
object BasicLit extends BasePrimitive
Expand All @@ -59,6 +63,7 @@ object ParserAst {
object FuncDecl extends ParserNode
object ValueSpec extends ParserNode
object CaseClause extends ParserNode
object CommClause extends ParserNode
object InterfaceType extends ParserNode
object FuncType extends ParserNode
object Ellipsis extends ParserNode
Expand All @@ -74,6 +79,9 @@ object ParserKeys {

val Assign = "Assign"
val Body = "Body"
val Call = "Call"
val Chan = "Chan"
val Comm = "Comm"
val Cond = "Cond"
val Decl = "Decl"
val Decls = "Decls"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ object UtilityConstants {
}
object Operator {
val unknown = "<operator>.unknown"
val send = "<operator>.send"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
package io.joern.go2cpg.passes.ast

import io.joern.go2cpg.testfixtures.GoCodeToCpgSuite
import io.shiftleft.codepropertygraph.generated.ControlStructureTypes
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.semanticcpg.language.*

class ConcurrencyTests extends GoCodeToCpgSuite {

"AST creation for defer statements" should {
"create call node for deferred function call" in {
val cpg = code("""
|package main
|func main() {
| f := openFile()
| defer f.Close()
|}""".stripMargin)

val closeCall = cpg.call.name("Close").head
closeCall.code shouldBe "f.Close()"
closeCall.lineNumber shouldBe Some(5)
}

"create call node for deferred function with arguments" in {
val cpg = code("""
|package main
|import "fmt"
|func main() {
| defer fmt.Println("done")
|}""".stripMargin)

val printlnCall = cpg.call.name("Println").head
printlnCall.code shouldBe "fmt.Println(\"done\")"
}
}

"AST creation for go statements" should {
"create call node for goroutine function call" in {
val cpg = code("""
|package main
|func handler(x int) {}
|func main() {
| go handler(42)
|}""".stripMargin)

val handlerCalls = cpg.call.name("handler").l
handlerCalls.size should be >= 1
handlerCalls.exists(_.code == "handler(42)") shouldBe true
}

"create call node for goroutine with simple function" in {
val cpg = code("""
|package main
|func worker(id int) {}
|func main() {
| go worker(1)
| go worker(2)
|}""".stripMargin)

val workerCalls = cpg.call.name("worker").l
workerCalls.size shouldBe 2
}
}

"AST creation for send statements" should {
"create operator call for channel send" in {
val cpg = code("""
|package main
|func main() {
| ch := make(chan int)
| ch <- 42
|}""".stripMargin)

val sendCall = cpg.call.name("<operator>.send").head
sendCall.code shouldBe "ch <- 42"
sendCall.argument.size shouldBe 2
}
}

"AST creation for select statements" should {
"create control structure for select with cases" in {
val cpg = code("""
|package main
|func main() {
| ch1 := make(chan int)
| ch2 := make(chan int)
| select {
| case msg := <-ch1:
| x := msg
| case ch2 <- 1:
| y := 2
| default:
| z := 3
| }
|}""".stripMargin)

val selectStmt = cpg.method.name("main").controlStructure.l
.filter(_.code == "select")
selectStmt.size shouldBe 1
selectStmt.head.controlStructureType shouldBe ControlStructureTypes.SWITCH
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,15 @@ class MethodTests extends GoCodeToCpgSuite {
|}
|""".stripMargin)

"Be correct with method node properties" ignore {
"Be correct with method node properties" in {
val List(x) = cpg.method.name("foo").l
x.name shouldBe "foo"
x.fullName shouldBe "main.foo"
x.code should startWith("func foo() (fpkg.Sample,error){")
// TODO: Tuple handling needs to be done properly to return both the types.
x.signature shouldBe "main.foo()(joern.io/sample/fpkg.Sample,error)"
x.signature shouldBe "main.foo()(joern.io/sample/fpkg.Sample, error)"
x.isExternal shouldBe false

x.order shouldBe 2
x.order shouldBe 1
x.filename shouldBe "Test0.go"
x.lineNumber shouldBe Option(4)
x.lineNumberEnd shouldBe Option(6)
Expand All @@ -153,8 +152,7 @@ class MethodTests extends GoCodeToCpgSuite {
"Be correct with return node" in {
cpg.method.name("foo").methodReturn.size shouldBe 1
val List(x) = cpg.method.name("foo").methodReturn.l
// TODO: Tuple handling needs to be done properly to return both the types.
x.typeFullName shouldBe "joern.io/sample/fpkg.Sample"
x.typeFullName shouldBe "(joern.io/sample/fpkg.Sample, error)"
}
}

Expand Down Expand Up @@ -1607,5 +1605,23 @@ class MethodTests extends GoCodeToCpgSuite {
// sem := make(chan int, concurrency)
// As well as example of "map"
// TODO: Add unit tests for lambda expression as a parameter
// TODO: Add unit test for tuple return

"Function with tuple return type" should {
val cpg = code("""
|package main
|func divide(a, b float64) (float64, error) {
| return a / b, nil
|}
|""".stripMargin)

"Be correct with method return type" in {
val List(x) = cpg.method.name("divide").methodReturn.l
x.typeFullName shouldBe "(float64, error)"
}

"Be correct with method signature" in {
val List(x) = cpg.method.name("divide").l
x.signature shouldBe "main.divide(float64, float64)(float64, error)"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,7 @@ class SwitchTests extends GoCodeToCpgSuite {
)
}

// TODO Need to handle `fallthrough` statements
"ast creation for fallthrough" ignore {
"ast creation for fallthrough" should {
"be correct" in {

val cpg = code("""package main
Expand Down