diff --git a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala index 595d30374e1a..c363a55bd1f4 100644 --- a/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala +++ b/joern-cli/frontends/javasrc2cpg/src/main/scala/io/joern/javasrc2cpg/passes/TypeInferencePass.scala @@ -22,7 +22,6 @@ class TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg) { .filterNot(_.fullName.startsWith(Defines.UnresolvedNamespace)) .filterNot(_.signature.startsWith(Defines.UnresolvedSignature)) .groupBy(_.name) - private case class NameParts(typeDecl: Option[String], signature: String) override def generateParts(): Array[Call] = { @@ -32,7 +31,12 @@ class TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg) { .toArray } - private def isMatchingMethod(method: Method, call: Call, callNameParts: NameParts): Boolean = { + private def isMatchingMethod( + method: Method, + call: Call, + callNameParts: NameParts, + ignoreArgTypes: Boolean = false + ): Boolean = { // An erroneous `this` argument is added for unresolved calls to static methods. val argSizeMod = if (method.modifier.modifierType.iterator.contains(ModifierTypes.STATIC)) 1 else 0 lazy val methodNameParts = getNameParts(method.name, method.fullName) @@ -44,7 +48,8 @@ class TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg) { lazy val typeDeclMatches = (callNameParts.typeDecl == methodNameParts.typeDecl) - parameterSizesMatch && argTypesMatch && typeDeclMatches + if ignoreArgTypes then parameterSizesMatch && typeDeclMatches + else parameterSizesMatch && argTypesMatch && typeDeclMatches } /** Check if argument types match by comparing exact full names. An argument type of `ANY` always matches. @@ -87,19 +92,41 @@ class TypeInferencePass(cpg: Cpg) extends ConcurrentWriterCpgPass[Call](cpg) { cache.get(callKey).toScala.getOrElse { val callNameParts = getNameParts(call.name, call.methodFullName) resolvedMethodIndex.get(call.name).flatMap { candidateMethods => - val candidateMethodsIter = candidateMethods.iterator - val uniqueMatchingMethod = - candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)).flatMap { method => - val otherMatchingMethod = candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts)) - // Only return a resulting method if exactly one matching method is found. - Option.when(otherMatchingMethod.isEmpty)(method) - } + val uniqueMatchingMethod = retreiveMatchingMethod(candidateMethods, call, callNameParts) match { + case Some(method) => Some(method) + case None => retreiveMatchingMethod(candidateMethods, call, callNameParts, ignoreArgTypes = true) + } cache.put(callKey, uniqueMatchingMethod) uniqueMatchingMethod } } } + /** Return a method only if there exists a one to one mapping of call to method node + * @param candidateMethods + * @param call + * @param callNameParts + * @param ignoreArgTypes + * @return + */ + private def retreiveMatchingMethod( + candidateMethods: List[Method], + call: Call, + callNameParts: NameParts, + ignoreArgTypes: Boolean = false + ): Option[Method] = { + val candidateMethodsIter = candidateMethods.iterator + val uniqueMatchingMethod = + candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts, ignoreArgTypes = ignoreArgTypes)).flatMap { + method => + val otherMatchingMethod = + candidateMethodsIter.find(isMatchingMethod(_, call, callNameParts, ignoreArgTypes = ignoreArgTypes)) + // Only return a resulting method if exactly one matching method is found. + Option.when(otherMatchingMethod.isEmpty)(method) + } + uniqueMatchingMethod + } + override def runOnPart(diffGraph: DiffGraphBuilder, call: Call): Unit = { getReplacementMethod(call).foreach { replacementMethod => diffGraph.setNodeProperty(call, PropertyNames.MethodFullName, replacementMethod.fullName) diff --git a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala index 5f2562cbd982..7d8853695566 100644 --- a/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala +++ b/joern-cli/frontends/javasrc2cpg/src/test/scala/io/joern/javasrc2cpg/querying/TypeInferenceTests.scala @@ -550,3 +550,42 @@ class TypeInferenceTests extends JavaSrcCode2CpgFixture { } } + +class TypeInferenceByArgSizeTests extends JavaSrcCode2CpgFixture { + "Calls having same namespace as a method" should { + val cpg = code(""" + |import com.myorg.client.Client; + |import com.myorg.config.RestConfigBase; + |import com.myorg.config.RestConfig; + |import com.myorg.interceptor.Interceptor; + | + |public class Sample { + | + | public static Client createClient(RestConfigBase config) { + | return new MyClient(config); + | } + | + | public static Client createClient(RestConfigBase config, Interceptor interceptor) { + | return new MyClient(config, interceptor); + | } + | + | Client getClient() { + | return Sample.createClient(new RestConfig("someUrl", "othervalue")); + | } + | + |} + |""".stripMargin) + + "have resolved methodFullName if argument and parameter size matches" in { + val createClientCalls = cpg.call("createClient").l + + createClientCalls.size shouldBe 1 + createClientCalls.methodFullName.l shouldBe List( + "Sample.createClient:com.myorg.client.Client(com.myorg.config.RestConfigBase)" + ) + createClientCalls.callee.fullName.l shouldBe List( + "Sample.createClient:com.myorg.client.Client(com.myorg.config.RestConfigBase)" + ) + } + } +}