Skip to content

Commit

Permalink
Keep local variable names feature (#215)
Browse files Browse the repository at this point in the history
  • Loading branch information
sergeypospelov authored Feb 8, 2024
1 parent 1a0fb7e commit 389ea0f
Show file tree
Hide file tree
Showing 18 changed files with 263 additions and 75 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,9 @@ abstract class BaseAnalysisTest : BaseTest() {
// TODO: think about better assertions here
assertEquals(expectedLocations.size, sinks.size)
expectedLocations.forEach { expected ->
assertTrue(sinks.any { it.contains(expected) })
assertTrue(sinks.any { it.contains(expected) }) {
"$expected unmatched in:\n${sinks.joinToString("\n")}"
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class NpeAnalysisTest : BaseAnalysisTest() {

@Test
fun `analyze simple NPE`() {
testOneMethod<NpeExamples>("npeOnLength", listOf("%3 = %0.length()"))
testOneMethod<NpeExamples>("npeOnLength", listOf("%3 = x.length()"))
}

@Test
Expand All @@ -72,7 +72,7 @@ class NpeAnalysisTest : BaseAnalysisTest() {
fun `analyze NPE after fun with two exits`() {
testOneMethod<NpeExamples>(
"npeAfterTwoExits",
listOf("%4 = %0.length()", "%5 = %1.length()")
listOf("%4 = x.length()", "%5 = y.length()")
)
}

Expand All @@ -85,15 +85,15 @@ class NpeAnalysisTest : BaseAnalysisTest() {
fun `consecutive NPEs handled properly`() {
testOneMethod<NpeExamples>(
"consecutiveNPEs",
listOf("%2 = arg$0.length()", "%4 = arg$0.length()")
listOf("a = x.length()", "c = x.length()")
)
}

@Test
fun `npe on virtual call when possible`() {
testOneMethod<NpeExamples>(
"possibleNPEOnVirtualCall",
listOf("%0 = arg\$0.length()")
listOf("%0 = x.length()")
)
}

Expand All @@ -107,7 +107,7 @@ class NpeAnalysisTest : BaseAnalysisTest() {

@Test
fun `basic test for NPE on fields`() {
testOneMethod<NpeExamples>("simpleNPEOnField", listOf("%8 = %6.length()"))
testOneMethod<NpeExamples>("simpleNPEOnField", listOf("len2 = second.length()"))
}

@Disabled("Flowdroid architecture not supported for async ifds yet")
Expand Down Expand Up @@ -152,7 +152,7 @@ class NpeAnalysisTest : BaseAnalysisTest() {

@Test
fun `NPE on uninitialized array element dereferencing`() {
testOneMethod<NpeExamples>("simpleArrayNPE", listOf("%5 = %4.length()"))
testOneMethod<NpeExamples>("simpleArrayNPE", listOf("b = %4.length()"))
}

@Test
Expand All @@ -174,7 +174,7 @@ class NpeAnalysisTest : BaseAnalysisTest() {

@Test
fun `dereferencing field of null object`() {
testOneMethod<NpeExamples>("npeOnFieldDeref", listOf("%1 = %0.field"))
testOneMethod<NpeExamples>("npeOnFieldDeref", listOf("s = a.field"))
}

@Test
Expand Down
45 changes: 43 additions & 2 deletions jacodb-api/src/main/kotlin/org/jacodb/api/cfg/JcInst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -811,8 +811,10 @@ interface JcLocal : JcSimpleValue {
val name: String
}

/**
* @param name isn't considered in `equals` and `hashcode`
*/
data class JcArgument(val index: Int, override val name: String, override val type: JcType) : JcLocal {

companion object {
@JvmStatic
fun of(index: Int, name: String?, type: JcType): JcArgument {
Expand All @@ -825,14 +827,53 @@ data class JcArgument(val index: Int, override val name: String, override val ty
override fun <T> accept(visitor: JcExprVisitor<T>): T {
return visitor.visitJcArgument(this)
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as JcArgument

if (index != other.index) return false
if (type != other.type) return false

return true
}

override fun hashCode(): Int {
var result = index
result = 31 * result + type.hashCode()
return result
}
}

data class JcLocalVar(override val name: String, override val type: JcType) : JcLocal {
/**
* @param name isn't considered in `equals` and `hashcode`
*/
data class JcLocalVar(val index: Int, override val name: String, override val type: JcType) : JcLocal {
override fun toString(): String = name

override fun <T> accept(visitor: JcExprVisitor<T>): T {
return visitor.visitJcLocalVar(this)
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as JcLocalVar

if (index != other.index) return false
if (type != other.type) return false

return true
}

override fun hashCode(): Int {
var result = index
result = 31 * result + type.hashCode()
return result
}
}

interface JcComplexValue : JcValue
Expand Down
44 changes: 43 additions & 1 deletion jacodb-api/src/main/kotlin/org/jacodb/api/cfg/JcRawInst.kt
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,9 @@ data class JcRawThis(override val typeName: TypeName) : JcRawSimpleValue {
}
}

/**
* @param name isn't considered in `equals` and `hashcode`
*/
data class JcRawArgument(val index: Int, override val name: String, override val typeName: TypeName) : JcRawLocal {
companion object {
@JvmStatic
Expand All @@ -822,14 +825,53 @@ data class JcRawArgument(val index: Int, override val name: String, override val
override fun <T> accept(visitor: JcRawExprVisitor<T>): T {
return visitor.visitJcRawArgument(this)
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as JcRawArgument

if (index != other.index) return false
if (typeName != other.typeName) return false

return true
}

override fun hashCode(): Int {
var result = index
result = 31 * result + typeName.hashCode()
return result
}
}

data class JcRawLocalVar(override val name: String, override val typeName: TypeName) : JcRawLocal {
/**
* @param name isn't considered in `equals` and `hashcode`
*/
data class JcRawLocalVar(val index: Int, override val name: String, override val typeName: TypeName) : JcRawLocal {
override fun toString(): String = name

override fun <T> accept(visitor: JcRawExprVisitor<T>): T {
return visitor.visitJcRawLocalVar(this)
}

override fun equals(other: Any?): Boolean {
if (this === other) return true
if (javaClass != other?.javaClass) return false

other as JcRawLocalVar

if (index != other.index) return false
if (typeName != other.typeName) return false

return true
}

override fun hashCode(): Int {
var result = index
result = 31 * result + typeName.hashCode()
return result
}
}

sealed interface JcRawComplexValue : JcRawValue
Expand Down
4 changes: 2 additions & 2 deletions jacodb-core/src/main/kotlin/org/jacodb/impl/JcDatabaseImpl.kt
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ class JcDatabaseImpl(

private fun List<JcClasspathFeature>?.appendBuiltInFeatures(): List<JcClasspathFeature> {
if (this != null && any { it is ClasspathCache }) {
return this + listOf(KotlinMetadata, MethodInstructionsFeature)
return this + listOf(KotlinMetadata, MethodInstructionsFeature(settings.keepLocalVariableNames))
}
return listOf(ClasspathCache(settings.cacheSettings), KotlinMetadata, MethodInstructionsFeature) + orEmpty()
return listOf(ClasspathCache(settings.cacheSettings), KotlinMetadata, MethodInstructionsFeature(settings.keepLocalVariableNames)) + orEmpty()
}

override suspend fun classpath(dirOrJars: List<File>, features: List<JcClasspathFeature>?): JcClasspath {
Expand Down
7 changes: 7 additions & 0 deletions jacodb-core/src/main/kotlin/org/jacodb/impl/JcSettings.kt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ class JcSettings {

var persistentClearOnStart: Boolean? = null

var keepLocalVariableNames: Boolean = false
private set

/** jar files which should be loaded right after database is created */
var predefinedDirOrJars: List<File> = persistentListOf()
private set
Expand Down Expand Up @@ -101,6 +104,10 @@ class JcSettings {
predefinedDirOrJars = (predefinedDirOrJars + files).toPersistentList()
}

fun keepLocalVariableNames() {
keepLocalVariableNames = true
}

/**
* builder for watching file system changes
* @param delay - delay between syncs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@ import org.jacodb.impl.cfg.VirtualMethodRefImpl
import org.jacodb.impl.cfg.methodRef
import kotlin.collections.set

class StringConcatSimplifierTransformer(classpath: JcClasspath, private val list: JcInstList<JcInst>) : DefaultJcInstVisitor<JcInst> {
class StringConcatSimplifierTransformer(classpath: JcClasspath, private val list: JcInstList<JcInst>) :
DefaultJcInstVisitor<JcInst> {

override val defaultInstHandler: (JcInst) -> JcInst
get() = { it }
Expand All @@ -39,6 +40,10 @@ class StringConcatSimplifierTransformer(classpath: JcClasspath, private val list

private val stringType = classpath.findTypeOrNull<String>() as JcClassType

private var localCounter = list
.flatMap { it.values.filterIsInstance<JcLocalVar>() }
.maxOfOrNull { it.index }?.plus(1) ?: 0

fun transform(): JcInstList<JcInst> {
var changed = false
for (inst in list) {
Expand Down Expand Up @@ -101,7 +106,7 @@ class StringConcatSimplifierTransformer(classpath: JcClasspath, private val list
it.name == "toString" && it.parameters.size == 1 && it.parameters.first().type == value.type
}
val toStringExpr = JcStaticCallExpr(method.methodRef(), listOf(value))
val assignment = JcLocalVar("${value}String", stringType)
val assignment = JcLocalVar(localCounter++, "${value}String", stringType)
instList += JcAssignInst(inst.location, assignment, toStringExpr)
assignment
}
Expand All @@ -114,7 +119,7 @@ class StringConcatSimplifierTransformer(classpath: JcClasspath, private val list
}
val methodRef = VirtualMethodRefImpl.of(boxedType, method)
val toStringExpr = JcVirtualCallExpr(methodRef, value, emptyList())
val assignment = JcLocalVar("${value}String", stringType)
val assignment = JcLocalVar(localCounter++, "${value}String", stringType)
instList += JcAssignInst(inst.location, assignment, toStringExpr)
assignment
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class JcMethodImpl(
internal fun parameterTypeAnnotationInfos(parameterIndex: Int): List<AnnotationInfo> =
methodInfo.annotations.filter {
it.typeRef != null && TypeReference(it.typeRef).sort == TypeReference.METHOD_FORMAL_PARAMETER
&& TypeReference(it.typeRef).formalParameterIndex == parameterIndex
&& TypeReference(it.typeRef).formalParameterIndex == parameterIndex
}

override val description get() = methodInfo.desc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class JcInstListBuilder(val method: JcMethod,val instList: JcInstList<JcRawInst>
inst.lhv.let { unprocessedLhv ->
if (unprocessedLhv is JcRawLocalVar && unprocessedLhv.typeName == UNINIT_THIS) {
convertedLocalVars.getOrPut(unprocessedLhv) {
JcRawLocalVar(unprocessedLhv.name, inst.rhv.typeName)
JcRawLocalVar(unprocessedLhv.index, unprocessedLhv.name, inst.rhv.typeName)
}
} else {
unprocessedLhv
Expand Down Expand Up @@ -338,8 +338,8 @@ class JcInstListBuilder(val method: JcMethod,val instList: JcInstList<JcRawInst>

override fun visitJcRawLocalVar(value: JcRawLocalVar): JcExpr =
convertedLocalVars[value]?.let { replacementForLocalVar ->
JcLocalVar(replacementForLocalVar.name, replacementForLocalVar.typeName.asType())
} ?: JcLocalVar(value.name, value.typeName.asType())
JcLocalVar(replacementForLocalVar.index, replacementForLocalVar.name, replacementForLocalVar.typeName.asType())
} ?: JcLocalVar(value.index, value.name, value.typeName.asType())

override fun visitJcRawFieldRef(value: JcRawFieldRef): JcExpr {
val type = value.declaringClass.asType() as JcClassType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ open class JcInstListImpl<INST>(
else -> " $it"
}
}

}

class JcMutableInstListImpl<INST>(instructions: List<INST>) : JcInstListImpl<INST>(instructions),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package org.jacodb.impl.cfg

import org.jacodb.api.JcMethod
import org.jacodb.api.JcParameter
import org.jacodb.api.PredefinedPrimitives
import org.jacodb.api.TypeName
import org.jacodb.api.cfg.*
Expand All @@ -26,6 +27,7 @@ import org.objectweb.asm.Opcodes
import org.objectweb.asm.Opcodes.H_GETSTATIC
import org.objectweb.asm.Type
import org.objectweb.asm.tree.*
import java.util.ArrayList

private val PredefinedPrimitives.smallIntegers get() = setOf(Boolean, Byte, Char, Short, Int)

Expand Down Expand Up @@ -116,12 +118,26 @@ class MethodNodeBuilder(
}

private fun initializeFrame(method: JcMethod) {
var staticInc = 0
if (!method.isStatic) {
val thisRef = JcRawThis(method.enclosingClass.name.typeName())
locals[thisRef] = localIndex++
staticInc = 1
}

val variables = method.asmNode().localVariables.orEmpty().sortedBy(LocalVariableNode::index)

fun getName(parameter: JcParameter): String? {
val idx = parameter.index + staticInc
return if (idx < variables.size) {
variables[idx].name
} else {
parameter.name
}
}

for (parameter in method.parameters) {
val argument = JcRawArgument.of(parameter.index, parameter.name, parameter.type)
val argument = JcRawArgument.of(parameter.index, getName(parameter), parameter.type)
locals[argument] = localIndex
if (argument.typeName.isDWord) localIndex += 2
else localIndex++
Expand Down
Loading

0 comments on commit 389ea0f

Please sign in to comment.