Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import io.joern.pysrc2cpg.memop.*
import io.joern.pysrc2cpg.memop.MemoryOperation.{Del, Load, Store}
import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants.builtinPrefix
import io.joern.pythonparser.{AstPrinter, ast}
import io.joern.pythonparser.ast.{Arguments, MatchAs, iast, iexpr, istmt}
import io.joern.pythonparser.ast.{Arguments, MatchAs, MatchClass, MatchMapping, MatchOr, MatchSequence, MatchSingleton, MatchStar, MatchValue, iast, iexpr, ipattern, istmt}
import io.joern.x2cpg.frontendspecific.pysrc2cpg.Constants
import io.joern.x2cpg.{AstCreatorBase, ValidationMode}
import io.shiftleft.codepropertygraph.generated.*
Expand Down Expand Up @@ -1243,37 +1243,235 @@ class PythonAstVisitor(
createBlock(blockStmts, lineAndCol)
}

// TODO add case pattern and guard statements to not just as string in the JUMP_TARGET to the CPG
// but rather as proper AST constructs.
/** Converts a Python `match` statement into a MATCH control structure node whose condition is the match subject
 * and whose body is a block of jump targets, each followed by the block for that case.
 *
 * NOTE(review): this span looks like it was captured from a rendered diff, so pre-change and post-change lines
 * are interleaved (for example two consecutive `nodeBuilder.controlStructureNode(...)` lines, two
 * `val jumpTarget` definitions and two trailing result expressions). Keep only one line of each pair before
 * compiling — TODO confirm which against the actual commit.
 */
def convert(matchStmt: ast.Match): NewNode = {
// Position of the whole match statement, reused for the nodes created below.
val lineAndCol = lineAndColOf(matchStmt)
val controlStructureNode =
nodeBuilder.controlStructureNode("match ... : ...", ControlStructureTypes.MATCH, lineAndColOf(matchStmt))
nodeBuilder.controlStructureNode("match ... : ...", ControlStructureTypes.MATCH, lineAndCol)

val matchSubject = convert(matchStmt.subject)
// For simple Name subjects, reference the name directly in pattern assignments.
// For complex expressions, create a temp variable so it's evaluated once and referenceable.
// The tuple holds: the name pattern bindings read from, the node used as the condition, and any
// assignment that has to run before the control structure.
val (subjectRefName, matchSubjectNode, prefixNodes) = matchStmt.subject match {
case ast.Name(id, _) =>
(id, convert(matchStmt.subject), Seq.empty[nodes.NewNode])
case _ =>
val tmpName = getUnusedName()
val subjectExpr = convert(matchStmt.subject)
val tmpAssign = createAssignmentToIdentifier(tmpName, subjectExpr, lineAndCol)
val tmpRef = createIdentifierNode(tmpName, Load, lineAndCol)
(tmpName, tmpRef, Seq(tmpAssign))
}

// Every case contributes two siblings: a jump target carrying the printed pattern (and guard), then a block.
val caseBlocks = matchStmt.cases.flatMap { caseStmt =>
val jumpTargetCode =
caseStmt.pattern match {
case MatchAs(None, _, _) if caseStmt.guard.isEmpty =>
// TODO For the moment we have to use "default" because otherwise the CfgCreator does not detect
// the jump target as the default case.
// Use "default" because the CfgCreator checks for this to detect the default case.
"default"
case pattern =>
val printer = new AstPrinter("")
"case " + printer.print(pattern) + caseStmt.guard.map(g => " if " + printer.print(g)).getOrElse("")
}
val jumpTarget = nodeBuilder.jumpNode(jumpTargetCode)
val bodyNodes = caseStmt.body.map(convert)
jumpTarget :: createBlock(bodyNodes, lineAndColOf(caseStmt.pattern)) :: Nil
val jumpTarget = nodeBuilder.jumpNode(jumpTargetCode)
// Assignments that bind the pattern's captured names are placed ahead of the case body.
val patternAssignments = lowerMatchPatternBindings(caseStmt.pattern, subjectRefName, lineAndColOf(caseStmt.pattern))
val bodyNodes = caseStmt.body.map(convert)
jumpTarget :: createBlock(patternAssignments ++ bodyNodes, lineAndColOf(caseStmt.pattern)) :: Nil
}

val switchBodyBlock = createBlock(caseBlocks, lineAndColOf(matchStmt))
val switchBodyBlock = createBlock(caseBlocks, lineAndCol)

// The subject is attached as the condition (AST child 1); the case blocks form AST child 2.
edgeBuilder.conditionEdge(matchSubject, controlStructureNode)
addAstChildNodes(controlStructureNode, 1, matchSubject)
edgeBuilder.conditionEdge(matchSubjectNode, controlStructureNode)
addAstChildNodes(controlStructureNode, 1, matchSubjectNode)
addAstChildNodes(controlStructureNode, 2, switchBodyBlock)

controlStructureNode
// When a temp assignment was created for the subject it must precede the control structure,
// so both are wrapped in one block; otherwise the control structure is returned directly.
if (prefixNodes.nonEmpty) {
createBlock(prefixNodes :+ controlStructureNode, lineAndCol)
} else {
controlStructureNode
}
}

/** Lower the variable bindings of a match pattern into assignment nodes that read from the subject.
 *
 * For example, `case [a, b]:` matched against subject `x` produces `a = x[0]` and `b = x[1]`. In a sequence
 * pattern that contains a star, elements after the star are addressed from the end of the sequence, so
 * `case [a, *rest, b]:` produces `b = x[-1]` rather than `b = x[2]`.
 *
 * Known simplifications: a star or `**rest` capture lowered here is bound to the whole subject, only the first
 * alternative of an or-pattern is lowered, and positional class sub-patterns are read as `subject[index]`.
 *
 * @param pattern
 *   the case pattern whose captures are lowered
 * @param subjectRefName
 *   name of the identifier holding the value the pattern is matched against
 * @param lineAndCol
 *   position attached to every generated node
 * @return
 *   the generated assignments in pattern order; empty when the pattern captures nothing
 */
private def lowerMatchPatternBindings(
  pattern: ipattern,
  subjectRefName: String,
  lineAndCol: LineAndColumn
): Seq[nodes.NewNode] = {
  // Builds `name = subject`; shared by plain captures, aliases, star captures and `**rest`.
  def bindWholeSubject(name: String): nodes.NewNode = {
    val subjectRef = createIdentifierNode(subjectRefName, Load, lineAndCol)
    createAssignmentToIdentifier(name, subjectRef, lineAndCol)
  }

  pattern match {
    case MatchAs(None, Some(name), _) =>
      // Catch-all with binding: `case x:` → `x = subject`
      Seq(bindWholeSubject(name))

    case MatchAs(Some(inner), Some(name), _) =>
      // Pattern with alias: `case [a, b] as whole:` → inner bindings first, then `whole = subject`
      lowerMatchPatternBindings(inner, subjectRefName, lineAndCol) :+ bindWholeSubject(name)

    case MatchAs(_, None, _) =>
      // Wildcard `case _:` or unnamed pattern — no bindings
      Seq.empty

    case MatchSequence(patterns, _) =>
      val indexed      = patterns.zipWithIndex.toSeq
      val starPosition = indexed.collectFirst { case (_: MatchStar, position) => position }
      indexed.flatMap { case (elemPattern, position) =>
        // The star absorbs a variable number of items, so anything after it can only be
        // addressed relative to the end of the sequence (last element → -1).
        val index = if (starPosition.exists(_ < position)) position - indexed.size else position
        lowerSequenceElementBinding(elemPattern, subjectRefName, index, lineAndCol)
      }

    case MatchStar(Some(name), _) =>
      // Star capture: `*rest` — bind to subject (simplified)
      Seq(bindWholeSubject(name))

    case MatchMapping(keys, patterns, rest, _) =>
      val keyBindings = keys.zip(patterns).flatMap { case (key, valuePattern) =>
        lowerMappingElementBinding(valuePattern, subjectRefName, key, lineAndCol)
      }.toSeq
      // `**rest` is bound to the whole subject (simplified).
      keyBindings ++ rest.map(bindWholeSubject).toSeq

    case MatchClass(_, patterns, kwdAttrs, kwdPatterns, _) =>
      val positionalBindings = patterns.zipWithIndex.flatMap { case (elemPattern, index) =>
        lowerSequenceElementBinding(elemPattern, subjectRefName, index, lineAndCol)
      }.toSeq
      val keywordBindings = kwdAttrs.zip(kwdPatterns).flatMap { case (attrName, attrPattern) =>
        lowerAttributeBinding(attrPattern, subjectRefName, attrName, lineAndCol)
      }.toSeq
      positionalBindings ++ keywordBindings

    case MatchOr(patterns, _) =>
      // All alternatives must bind the same names in Python. Process the first.
      patterns.headOption.map(lowerMatchPatternBindings(_, subjectRefName, lineAndCol)).getOrElse(Seq.empty)

    case _ =>
      // Literal/singleton patterns (MatchValue, MatchSingleton) and an unnamed star bind nothing.
      Seq.empty
  }
}

/** Lower the pattern found at one position of a sequence-like pattern.
 *
 * A plain capture becomes `name = subject[index]`. A nested pattern such as the inner list of `[a, [b, c]]` first
 * stores `subject[index]` in a fresh temp variable and is then lowered against that temp. Nested patterns that
 * capture nothing produce no nodes at all, matching the behaviour of the mapping and attribute helpers.
 *
 * NOTE(review): a star capture is lowered as `rest = subject[index]`, i.e. a single element rather than a slice.
 *
 * @param elemPattern
 *   the sub-pattern at this position
 * @param subjectRefName
 *   name of the identifier holding the sequence being matched
 * @param index
 *   position of the element; it is rendered verbatim via `index.toString` into the index literal
 * @param lineAndCol
 *   position attached to every generated node
 * @return
 *   the generated assignments; empty when the sub-pattern captures nothing
 */
private def lowerSequenceElementBinding(
  elemPattern: ipattern,
  subjectRefName: String,
  index: Int,
  lineAndCol: LineAndColumn
): Seq[nodes.NewNode] = {
  // Builds the `subject[index]` access expression used by every binding below.
  def elementAccess() = {
    val subjectRef = createIdentifierNode(subjectRefName, Load, lineAndCol)
    val indexNode  = nodeBuilder.intLiteralNode(index.toString, lineAndCol)
    createIndexAccess(subjectRef, indexNode, lineAndCol)
  }

  elemPattern match {
    case MatchAs(None, Some(name), _) =>
      // Direct binding: `a` in `case [a, b]:` → `a = subject[index]`
      Seq(createAssignmentToIdentifier(name, elementAccess(), lineAndCol))

    case MatchStar(Some(name), _) =>
      // Star capture in sequence: `*rest` in `case [a, *rest]:` → `rest = subject[index]`
      Seq(createAssignmentToIdentifier(name, elementAccess(), lineAndCol))

    case nested @ (_: MatchSequence | _: MatchMapping | _: MatchClass | _: MatchOr) =>
      if (extractPatternBindingNames(nested).isEmpty) {
        // Nothing is captured below this point, so a temp would never be read.
        Seq.empty
      } else {
        // Nested pattern: create a temp variable for `subject[index]`, then recurse
        val tmpName   = getUnusedName()
        val tmpAssign = createAssignmentToIdentifier(tmpName, elementAccess(), lineAndCol)
        tmpAssign +: lowerMatchPatternBindings(nested, tmpName, lineAndCol)
      }

    case MatchAs(Some(inner), Some(name), _) =>
      // Aliased nested pattern: `[a, b] as x` at position → temp, inner bindings, then `x = temp`
      val access     = elementAccess()
      val tmpName    = getUnusedName()
      val tmpAssign  = createAssignmentToIdentifier(tmpName, access, lineAndCol)
      val nameAssign = createAssignmentToIdentifier(name, createIdentifierNode(tmpName, Load, lineAndCol), lineAndCol)
      (tmpAssign +: lowerMatchPatternBindings(inner, tmpName, lineAndCol)) :+ nameAssign

    case _ =>
      // MatchValue, MatchSingleton, MatchAs(_, None, _) — no bindings
      Seq.empty
  }
}

/** Lower the value pattern of a single mapping-pattern entry, reading the entry as `subject[key]`.
 *
 * A plain capture becomes `name = subject[key]`. Any other value pattern that captures at least one name gets
 * `subject[key]` stored in a fresh temp variable and is lowered against that temp; value patterns without
 * captures produce nothing.
 *
 * @param valuePattern
 *   the pattern on the value side of the entry
 * @param subjectRefName
 *   name of the identifier holding the mapping being matched
 * @param key
 *   the key expression of the entry
 * @param lineAndCol
 *   position attached to every generated node
 */
private def lowerMappingElementBinding(
  valuePattern: ipattern,
  subjectRefName: String,
  key: iexpr,
  lineAndCol: LineAndColumn
): Seq[nodes.NewNode] = {
  // Builds the `subject[key]` access expression.
  def entryAccess() = {
    val mappingRef = createIdentifierNode(subjectRefName, Load, lineAndCol)
    val convertedKey = convert(key)
    createIndexAccess(mappingRef, convertedKey, lineAndCol)
  }

  valuePattern match {
    case MatchAs(None, Some(name), _) =>
      Seq(createAssignmentToIdentifier(name, entryAccess(), lineAndCol))

    case other if extractPatternBindingNames(other).isEmpty =>
      // No captures below this entry, so no temp is needed.
      Seq.empty

    case other =>
      // Park `subject[key]` in a temp and lower the nested pattern against it.
      val holderName   = getUnusedName()
      val holderAssign = createAssignmentToIdentifier(holderName, entryAccess(), lineAndCol)
      holderAssign +: lowerMatchPatternBindings(other, holderName, lineAndCol)
  }
}

/** Lower the pattern attached to one keyword attribute of a class pattern, reading it as `subject.attr`.
 *
 * A plain capture becomes `name = subject.attr`. Any other pattern that captures at least one name gets
 * `subject.attr` stored in a fresh temp variable and is lowered against that temp; patterns without captures
 * produce nothing.
 *
 * @param attrPattern
 *   the pattern the attribute value is matched against
 * @param subjectRefName
 *   name of the identifier holding the object being matched
 * @param attrName
 *   name of the attribute that is read
 * @param lineAndCol
 *   position attached to every generated node
 */
private def lowerAttributeBinding(
  attrPattern: ipattern,
  subjectRefName: String,
  attrName: String,
  lineAndCol: LineAndColumn
): Seq[nodes.NewNode] = {
  // Builds the `subject.attr` access expression.
  def attributeAccess() = {
    val ownerRef = createIdentifierNode(subjectRefName, Load, lineAndCol)
    createFieldAccess(ownerRef, attrName, lineAndCol)
  }

  attrPattern match {
    case MatchAs(None, Some(name), _) =>
      Seq(createAssignmentToIdentifier(name, attributeAccess(), lineAndCol))

    case other if extractPatternBindingNames(other).isEmpty =>
      // No captures below this attribute, so no temp is needed.
      Seq.empty

    case other =>
      // Park `subject.attr` in a temp and lower the nested pattern against it.
      val holderName   = getUnusedName()
      val holderAssign = createAssignmentToIdentifier(holderName, attributeAccess(), lineAndCol)
      holderAssign +: lowerMatchPatternBindings(other, holderName, lineAndCol)
  }
}

/** Collect the names a pattern binds, in pattern order.
 *
 * Callers use an empty result to decide that no temp variable is needed for a nested pattern. For an or-pattern
 * only the first alternative is inspected.
 *
 * @param pattern
 *   the pattern to inspect
 * @return
 *   the captured names; empty when the pattern binds nothing
 */
private def extractPatternBindingNames(pattern: ipattern): Seq[String] = {
  pattern match {
    case MatchAs(inner, name, _) =>
      // The alias itself comes first, followed by whatever the wrapped pattern captures.
      val innerNames = inner.map(extractPatternBindingNames).getOrElse(Seq.empty)
      name.toSeq ++ innerNames

    case MatchSequence(patterns, _) =>
      patterns.toSeq.flatMap(extractPatternBindingNames)

    case MatchStar(name, _) =>
      name.toSeq

    case MatchMapping(_, patterns, rest, _) =>
      val valueNames = patterns.toSeq.flatMap(extractPatternBindingNames)
      valueNames ++ rest.toSeq

    case MatchClass(_, patterns, _, kwdPatterns, _) =>
      // Positional sub-patterns first, keyword sub-patterns second.
      (patterns.toSeq ++ kwdPatterns.toSeq).flatMap(extractPatternBindingNames)

    case MatchOr(patterns, _) =>
      patterns.headOption.map(extractPatternBindingNames).getOrElse(Seq.empty)

    case _ =>
      Seq.empty
  }
}

def convert(raise: ast.Raise): NewNode = {
Expand Down
Loading