Skip to content

Commit

Permalink
Change operator extensions to use adhoc subtyping (#2750)
Browse files Browse the repository at this point in the history
  • Loading branch information
bbrehm authored Nov 28, 2023
1 parent d2c4c86 commit 4618ea7
Show file tree
Hide file tree
Showing 15 changed files with 66 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ class AstCreationPassTests extends AbstractPassTest {
val localZ = cpg.local.order(3)
localZ.name.l shouldBe List("z")

inside(cpg.method.name("method").ast.isCall.name(Operators.assignment).map(new OpNodes.Assignment(_)).l) {
inside(cpg.method.name("method").ast.isCall.name(Operators.assignment).cast[OpNodes.Assignment].l) {
case List(assignment) =>
assignment.target.code shouldBe "x"
assignment.source.start.isCall.name.l shouldBe List(Operators.addition)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package io.joern.ghidra2cpg.passes
import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.{Languages, nodes}
import io.shiftleft.passes.CpgPass
import overflowdb.BatchedUpdate

class MetaDataPass(filename: String, cpg: Cpg) extends CpgPass(cpg) {

override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = {
override def run(diffGraph: DiffGraphBuilder): Unit = {
diffGraph.addNode(
nodes
.NewTypeDecl()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class ExpressionsTests extends GoCodeToCpgSuite {

cpg.local.name.l shouldBe List("x", "y", "z")
val List(assignment) =
cpg.method.name("method").ast.isCall.name(Operators.assignment).map(new OpNodes.Assignment(_)).l
cpg.method.name("method").ast.isCall.name(Operators.assignment).cast[OpNodes.Assignment].l
assignment.target.code shouldBe "x"
assignment.source.start.isCall.name.l shouldBe List(Operators.addition)
val List(id1: Identifier, id2: Identifier) = assignment.source.astChildren.l: @unchecked
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,9 @@ private class RecoverForJavaScriptFile(cpg: Cpg, cu: File, builder: DiffGraphBui
case ::(fa: Call, ::(i: Identifier, _)) if fa.name == Operators.fieldAccess =>
symbolTable.append(
c,
visitIdentifierAssignedToFieldLoad(i, new FieldAccess(fa)).map(t => s"$t$pathSep$ConstructorMethodName")
visitIdentifierAssignedToFieldLoad(i, fa.asInstanceOf[FieldAccess]).map(t =>
s"$t$pathSep$ConstructorMethodName"
)
)
case _ => Set.empty
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
!name.isBlank && name.charAt(0).isUpper

override def assignments: Iterator[Assignment] =
cu.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_))
cu.ast.isCall.nameExact(Operators.assignment).cast[Assignment]

protected def unresolvedDynamicCalls: Iterator[Call] = cu.ast.isCall
.filter(_.dispatchType == DispatchTypes.DYNAMIC_DISPATCH)
Expand Down Expand Up @@ -110,7 +110,7 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
case ::(head: Literal, Nil) if head.typeFullName != "ANY" =>
Set(head.typeFullName)
case ::(head: Call, Nil) if head.name == Operators.fieldAccess =>
val fieldAccess = new FieldAccess(head)
val fieldAccess = head.asInstanceOf[FieldAccess]
val (sym, ts) = getSymbolFromCall(fieldAccess)
val cpgTypes = cpg.typeDecl
.fullNameExact(ts.map(_.compUnitFullName).toSeq: _*)
Expand Down Expand Up @@ -188,7 +188,7 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
}

override protected def getTypesFromCall(c: Call): Set[String] = c.name match {
case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c))))
case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(c.asInstanceOf[FieldAccess])))
case _ if symbolTable.contains(c) => symbolTable.get(c)
case Operators.indexAccess => getIndexAccessTypes(c)
case n => methodReturnValues(Seq(c.methodFullName))
Expand All @@ -197,7 +197,7 @@ private class RecoverForPhpFile(cpg: Cpg, cu: NamespaceBlock, builder: DiffGraph
override protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = {
def callName(x: Call) =
if (x.name == Operators.fieldAccess)
getFieldName(new FieldAccess(x))
getFieldName(x.asInstanceOf[FieldAccess])
else if (x.name == Operators.indexAccess)
indexAccessToCollectionVar(x)
.map(cv => s"${cv.identifier}[${cv.idx}]")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@ import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.PropertyNames
import io.shiftleft.codepropertygraph.generated.nodes.MethodParameterIn.PropertyDefaults
import io.shiftleft.passes.CpgPass
import overflowdb.BatchedUpdate
import io.shiftleft.semanticcpg.language._

/** Old CPGs use the `order` field to indicate the parameter index while newer CPGs use the `parameterIndex` field. This
* pass checks whether `parameterIndex` is not set, in which case the value of `order` is copied over.
*/
class ParameterIndexCompatPass(cpg: Cpg) extends CpgPass(cpg) {

override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = {
override def run(diffGraph: DiffGraphBuilder): Unit = {
cpg.parameter.foreach { param =>
if (param.index == PropertyDefaults.Index) {
diffGraph.setNodeProperty(param, PropertyNames.INDEX, param.order)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ abstract class XTypeRecoveryPass[CompilationUnitType <: AstNode](
.argumentOption(1)
.map {
case x: Call if x.name == Operators.fieldAccess =>
cpg.typeDecl.fullNameExact(FieldAccess(x).referencedMember.getKnownTypes.toSeq*)
cpg.typeDecl.fullNameExact(x.asInstanceOf[FieldAccess].referencedMember.getKnownTypes.toSeq*)
case x: Call if !x.name.startsWith("<operator>") =>
if (!x.typeFullName.matches(XTypeRecovery.unknownTypePattern.pattern.pattern()))
cpg.typeDecl.fullNameExact(x.typeFullName)
Expand Down Expand Up @@ -333,9 +333,9 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](

protected def assignments: Iterator[Assignment] = cu match {
case x: File =>
x.method.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map(new OpNodes.Assignment(_))
case x: Method => x.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).map(new OpNodes.Assignment(_))
case x => x.ast.isCall.nameExact(Operators.assignment).map(new OpNodes.Assignment(_))
x.method.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).cast[Assignment]
case x: Method => x.flatMap(_._callViaContainsOut).nameExact(Operators.assignment).cast[Assignment]
case x => x.ast.isCall.nameExact(Operators.assignment).cast[Assignment]
}

protected def returns: Iterator[Return] = cu match {
Expand Down Expand Up @@ -454,7 +454,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
protected def visitStatementsInBlock(b: Block, assignmentTarget: Option[Identifier] = None): Set[String] =
b.astChildren
.map {
case x: Call if x.name.startsWith(Operators.assignment) => visitAssignments(new Assignment(x))
case x: Call if x.name.startsWith(Operators.assignment) => visitAssignments(x.asInstanceOf[Assignment])
case x: Call if x.name.startsWith("<operator>") && assignmentTarget.isDefined =>
visitIdentifierAssignedToOperator(assignmentTarget.get, x, x.name)
case x: Identifier if symbolTable.contains(x) => symbolTable.get(x)
Expand Down Expand Up @@ -603,7 +603,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
protected def visitIdentifierAssignedToOperator(i: Identifier, c: Call, operation: String): Set[String] = {
operation match {
case Operators.alloc => visitIdentifierAssignedToConstructor(i, c)
case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, new FieldAccess(c))
case Operators.fieldAccess => visitIdentifierAssignedToFieldLoad(i, c.asInstanceOf[FieldAccess])
case Operators.indexAccess => visitIdentifierAssignedToIndexAccess(i, c)
case Operators.cast => visitIdentifierAssignedToCast(i, c)
case x => logger.debug(s"Unhandled operation $x (${c.code}) @ ${debugLocation(c)}"); Set.empty
Expand Down Expand Up @@ -688,7 +688,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
/** Given a call operation, will attempt to retrieve types from it.
*/
protected def getTypesFromCall(c: Call): Set[String] = c.name match {
case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(new FieldAccess(c))))
case Operators.fieldAccess => symbolTable.get(LocalVar(getFieldName(c.asInstanceOf[FieldAccess])))
case _ if symbolTable.contains(c) => methodReturnValues(symbolTable.get(c).toSeq)
case Operators.indexAccess => getIndexAccessTypes(c)
case n =>
Expand Down Expand Up @@ -734,7 +734,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
*/
protected def getSymbolFromCall(c: Call): (LocalKey, Set[FieldPath]) = c.name match {
case Operators.fieldAccess =>
val fa = new FieldAccess(c)
val fa = c.asInstanceOf[FieldAccess]
val fieldName = getFieldName(fa)
(LocalVar(fieldName), getFieldParents(fa).map(fp => FieldPath(fp, fieldName)))
case Operators.indexAccess => (indexAccessToCollectionVar(c).getOrElse(LocalVar(c.name)), Set.empty)
Expand Down Expand Up @@ -762,12 +762,12 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
case ::(i: Identifier, ::(f: FieldIdentifier, _)) if i.name.matches("(self|this)") => wrapName(f.canonicalName)
case ::(i: Identifier, ::(f: FieldIdentifier, _)) => wrapName(s"${i.name}$pathSep${f.canonicalName}")
case ::(c: Call, ::(f: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) =>
wrapName(getFieldName(new FieldAccess(c), suffix = f.canonicalName))
wrapName(getFieldName(c.asInstanceOf[FieldAccess], suffix = f.canonicalName))
case ::(_: Call, ::(f: FieldIdentifier, _)) if typesFromBaseCall.nonEmpty =>
// TODO: Handle this case better
wrapName(s"${typesFromBaseCall.head}$pathSep${f.canonicalName}")
case ::(f: FieldIdentifier, ::(c: Call, _)) if c.name.equals(Operators.fieldAccess) =>
wrapName(getFieldName(new FieldAccess(c), prefix = f.canonicalName))
wrapName(getFieldName(c.asInstanceOf[FieldAccess], prefix = f.canonicalName))
case ::(c: Call, ::(f: FieldIdentifier, _)) =>
// TODO: Handle this case better
val callCode = if (c.code.contains("(")) c.code.substring(c.code.indexOf("(")) else c.code
Expand Down Expand Up @@ -800,7 +800,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
Set.empty
}
} else if (c.name.equals(Operators.fieldAccess)) {
val fa = new FieldAccess(c)
val fa = c.asInstanceOf[FieldAccess]
val fieldName = getFieldName(fa)
associateTypes(LocalVar(fieldName), fa, getLiteralType(l))
} else {
Expand All @@ -819,7 +819,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
protected def indexAccessToCollectionVar(c: Call): Option[CollectionVar] = {
def callName(x: Call) =
if (x.name.equals(Operators.fieldAccess))
getFieldName(new FieldAccess(x))
getFieldName(x.asInstanceOf[FieldAccess])
else if (x.name.equals(Operators.indexAccess))
indexAccessToCollectionVar(x)
.map(cv => s"${cv.identifier}[${cv.idx}]")
Expand Down Expand Up @@ -854,7 +854,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
val dummyTypes = Set(s"$fieldName$pathSep${XTypeRecovery.DummyReturnType}")
associateInterproceduralTypes(i, base, fi, fieldName, dummyTypes)
case ::(c: Call, ::(fi: FieldIdentifier, _)) if c.name.equals(Operators.fieldAccess) =>
val baseName = getFieldName(new FieldAccess(c))
val baseName = getFieldName(c.asInstanceOf[FieldAccess])
// Build type regardless of length
// TODO: This is more prone to giving dummy values as it does not do global look-ups
// but this is okay for now
Expand Down Expand Up @@ -918,7 +918,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
def extractTypes(xs: List[CfgNode]): Set[String] = xs match {
case ::(head: Literal, Nil) => getLiteralType(head)
case ::(head: Call, Nil) if head.name == Operators.fieldAccess =>
val fieldAccess = new FieldAccess(head)
val fieldAccess = head.asInstanceOf[FieldAccess]
val (sym, ts) = getSymbolFromCall(fieldAccess)
val cpgTypes = cpg.typeDecl
.fullNameExact(ts.map(_.compUnitFullName).toSeq: _*)
Expand Down Expand Up @@ -1002,7 +1002,7 @@ abstract class RecoverForXCompilationUnit[CompilationUnitType <: AstNode](
// Case 3: 'i' is the receiver for a field access on member 'f'
case (Some(fieldAccess: Call), ::(i: Identifier, ::(f: FieldIdentifier, _)))
if fieldAccess.name == Operators.fieldAccess =>
setTypeForFieldAccess(new FieldAccess(fieldAccess), i, f)
setTypeForFieldAccess(fieldAccess.asInstanceOf[FieldAccess], i, f)
case _ =>
}
// Handle the node itself
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@ import io.shiftleft.codepropertygraph.Cpg
import io.shiftleft.codepropertygraph.generated.Properties
import io.shiftleft.passes.CpgPass
import io.shiftleft.semanticcpg.language._
import overflowdb.BatchedUpdate

object Overlays {

def appendOverlayName(cpg: Cpg, overlayName: String): Unit = {
new CpgPass(cpg) {
override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = {
override def run(diffGraph: DiffGraphBuilder): Unit = {
cpg.metaData.headOption match {
case Some(metaData) =>
val newValue = metaData.overlays :+ overlayName
Expand All @@ -24,7 +23,7 @@ object Overlays {

def removeLastOverlayName(cpg: Cpg): Unit = {
new CpgPass(cpg) {
override def run(diffGraph: BatchedUpdate.DiffGraphBuilder): Unit = {
override def run(diffGraph: DiffGraphBuilder): Unit = {
cpg.metaData.headOption match {
case Some(metaData) =>
val newValue = metaData.overlays.dropRight(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,23 @@ package io.shiftleft.semanticcpg.language.modulevariable
import io.shiftleft.codepropertygraph.generated.nodes.*
import io.shiftleft.codepropertygraph.generated.{Cpg, Operators}
import io.shiftleft.semanticcpg.language.*
import io.shiftleft.semanticcpg.language.modulevariable.OpNodes.ModuleVariable
import io.shiftleft.semanticcpg.language.modulevariable.OpNodes
import io.shiftleft.semanticcpg.language.operatorextension.OpNodes.FieldAccess
import overflowdb.traversal.help.Doc

/** Adds the `moduleVariables` step to an iterator of `Local` nodes.
*
* The step keeps only those locals for which `isModuleVariable` holds and exposes them as `OpNodes.ModuleVariable`.
*
* NOTE(review): two signature/body pairs for `moduleVariables` appear below (one wrapping with
* `new OpNodes.ModuleVariable(_)`, one using `.cast[OpNodes.ModuleVariable]`). They look like the removed and added
* sides of a diff; confirm which pair the real file keeps.
*/
class ModuleVariableAsLocalTraversal(traversal: Iterator[Local]) extends AnyVal {

@Doc(info = "Locals representing module variables")
def moduleVariables: Iterator[ModuleVariable] = {
traversal.filter(_.isModuleVariable).map(new OpNodes.ModuleVariable(_))
def moduleVariables: Iterator[OpNodes.ModuleVariable] = {
traversal.filter(_.isModuleVariable).cast[OpNodes.ModuleVariable]
}

}

class ModuleVariableAsIdentifierTraversal(traversal: Iterator[Identifier]) extends AnyVal {

@Doc(info = "Identifiers representing module variables")
def moduleVariables: Iterator[ModuleVariable] = {
def moduleVariables: Iterator[OpNodes.ModuleVariable] = {
traversal.flatMap(_._localViaRefOut).moduleVariables
}

Expand All @@ -28,7 +28,7 @@ class ModuleVariableAsIdentifierTraversal(traversal: Iterator[Identifier]) exten
class ModuleVariableAsFieldIdentifierTraversal(traversal: Iterator[FieldIdentifier]) extends AnyVal {

@Doc(info = "Field identifiers representing module variables")
def moduleVariables: Iterator[ModuleVariable] = {
def moduleVariables: Iterator[OpNodes.ModuleVariable] = {
traversal.flatMap { fieldIdentifier =>
Cpg(fieldIdentifier.graph()).method
.fullNameExact(fieldIdentifier.inFieldAccess.argument(1).isIdentifier.typeFullName.toSeq*)
Expand All @@ -43,7 +43,7 @@ class ModuleVariableAsFieldIdentifierTraversal(traversal: Iterator[FieldIdentifi
class ModuleVariableAsMemberTraversal(traversal: Iterator[Member]) extends AnyVal {

@Doc(info = "Members representing module variables")
def moduleVariables: Iterator[ModuleVariable] = {
def moduleVariables: Iterator[OpNodes.ModuleVariable] = {
val members = traversal.toList
lazy val memberNames = members.name.toSeq
members.headOption.map(m => Cpg(m.graph())) match
Expand All @@ -60,11 +60,11 @@ class ModuleVariableAsMemberTraversal(traversal: Iterator[Member]) extends AnyVa
class ModuleVariableAsExpressionTraversal(traversal: Iterator[Expression]) extends AnyVal {

@Doc(info = "Expression nodes representing module variables")
def moduleVariables: Iterator[ModuleVariable] = {
def moduleVariables: Iterator[OpNodes.ModuleVariable] = {
traversal.flatMap {
case x: Identifier => x.start.moduleVariables
case x: FieldIdentifier => x.start.moduleVariables
case x: Call if x.name == Operators.fieldAccess => new FieldAccess(x).fieldIdentifier.moduleVariables
case x: Call if x.name == Operators.fieldAccess => x.asInstanceOf[FieldAccess].fieldIdentifier.moduleVariables
case _ => Iterator.empty
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.shiftleft.semanticcpg.language.modulevariable

import io.shiftleft.codepropertygraph.generated.nodes.{Block, Local, Member}
import io.shiftleft.codepropertygraph.generated.nodes.{Block, Local, Member, StaticType}

/** Marker trait with no members of its own. In this file it is used only as the type argument of `StaticType` in
* `OpNodes.ModuleVariable`.
*/
trait ModuleVariableT
object OpNodes {

/** Represents a module-level global variable. This kind of node behaves like both a local variable and a field access
* and is common in languages such as Python/JavaScript.
*/
// Wrapper form: builds a new `Local` from the graph and id of `node`.
// NOTE(review): a class and a type alias both named `ModuleVariable` appear in this object. They look like the
// removed and added sides of a diff; confirm which one the real file keeps.
class ModuleVariable(node: Local) extends Local(node.graph(), node.id)

// Alias form: a `Local` refined with `StaticType[ModuleVariableT]`. There is no constructor to call for an alias.
type ModuleVariable = Local with StaticType[ModuleVariableT]

}
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package io.shiftleft.semanticcpg.language.operatorextension

import io.shiftleft.codepropertygraph.generated.nodes.Expression
import io.shiftleft.codepropertygraph.generated.nodes
import io.shiftleft.semanticcpg.language.*
import overflowdb.traversal.help
import overflowdb.traversal.help.Doc

/** Traversal steps available on an iterator of `OpNodes.Assignment` nodes.
*
* Both steps map over the wrapped iterator, producing one result per incoming assignment: `target` yields the
* assignment's `target` and `source` yields its `source`.
*
* NOTE(review): the annotation and each `def` appear twice below with different type spellings (`Expression` vs
* `nodes.Expression`, `classOf[OpNodes.Assignment]` vs `classOf[nodes.Call]`). They look like the removed and added
* sides of a diff; confirm which variant the real file keeps.
*/
@help.Traversal(elementType = classOf[OpNodes.Assignment])
@help.Traversal(elementType = classOf[nodes.Call])
class AssignmentTraversal(val traversal: Iterator[OpNodes.Assignment]) extends AnyVal {

@Doc(info = "Left-hand sides of assignments")
def target: Iterator[Expression] = traversal.map(_.target)
def target: Iterator[nodes.Expression] = traversal.map(_.target)

@Doc(info = "Right-hand sides of assignments")
def source: Iterator[Expression] = traversal.map(_.source)
def source: Iterator[nodes.Expression] = traversal.map(_.source)
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,22 @@ class NodeTypeStarters(cpg: Cpg) {
@Doc(info = "All assignments, including shorthand assignments that perform arithmetic (e.g., '+=')")
def assignment: Iterator[OpNodes.Assignment] =
callsWithNameIn(allAssignmentTypes)
.map(new OpNodes.Assignment(_))
.cast[OpNodes.Assignment]

@Doc(info = "All arithmetic operations, including shorthand assignments that perform arithmetic (e.g., '+=')")
def arithmetic: Iterator[OpNodes.Arithmetic] =
callsWithNameIn(allArithmeticTypes)
.map(new OpNodes.Arithmetic(_))
.cast[OpNodes.Arithmetic]

@Doc(info = "All array accesses")
def arrayAccess: Iterator[OpNodes.ArrayAccess] =
callsWithNameIn(allArrayAccessTypes)
.map(new OpNodes.ArrayAccess(_))
.cast[OpNodes.ArrayAccess]

@Doc(info = "Field accesses, both direct and indirect")
def fieldAccess: Iterator[OpNodes.FieldAccess] =
callsWithNameIn(allFieldAccessTypes)
.map(new OpNodes.FieldAccess(_))
.cast[OpNodes.FieldAccess]

private def callsWithNameIn(set: Set[String]) =
cpg.call.filter(x => set.contains(x.name))
Expand Down
Loading

0 comments on commit 4618ea7

Please sign in to comment.