Skip to content

Commit

Permalink
Add minor type inference fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lipen committed Dec 27, 2024
1 parent c599ea7 commit 602acc7
Show file tree
Hide file tree
Showing 26 changed files with 881 additions and 561 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import org.jacodb.ets.base.EtsParameterRef
import org.jacodb.ets.base.EtsStaticFieldRef
import org.jacodb.ets.base.EtsThis
import org.jacodb.ets.base.EtsValue
import org.jacodb.ets.model.EtsClassSignature

data class AccessPath(val base: AccessPathBase, val accesses: List<Accessor>) {
operator fun plus(accessor: Accessor) = AccessPath(base, accesses + accessor)
Expand Down Expand Up @@ -44,8 +45,8 @@ sealed interface AccessPathBase {
override fun toString(): String = "<this>"
}

object Static : AccessPathBase {
override fun toString(): String = "<static>"
data class Static(val clazz: EtsClassSignature) : AccessPathBase {
override fun toString(): String = "static(${clazz.name})"
}

data class Arg(val index: Int) : AccessPathBase {
Expand All @@ -54,6 +55,34 @@ sealed interface AccessPathBase {

data class Local(val name: String) : AccessPathBase {
override fun toString(): String = "local($name)"

fun tryGetOrdering(): Int? {
if (name.startsWith("%")) {
val ix = name.substring(1).toIntOrNull()
if (ix != null) {
return ix
}
}
if (name.startsWith("\$v")) {
val ix = name.substring(2).toIntOrNull()
if (ix != null) {
return 10_000 + ix
}
}
if (name.startsWith("\$temp")) {
val ix = name.substring(5).toIntOrNull()
if (ix != null) {
return 20_000 + ix
}
}
if (name.startsWith("_tmp")) {
val ix = name.substring(4).toIntOrNull()
if (ix != null) {
return 30_000 + ix
}
}
return null
}
}

data class Const(val constant: EtsConstant) : AccessPathBase {
Expand Down Expand Up @@ -86,7 +115,10 @@ fun EtsEntity.toPathOrNull(): AccessPath? = when (this) {
it + FieldAccessor(field.name)
}

is EtsStaticFieldRef -> AccessPath(AccessPathBase.Static, listOf(FieldAccessor(field.name)))
is EtsStaticFieldRef -> {
val base = AccessPathBase.Static(field.enclosingClass)
AccessPath(base, listOf(FieldAccessor(field.name)))
}

is EtsCastExpr -> arg.toPathOrNull()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class BackwardFlowFunctions(
val graph: ApplicationGraph<EtsMethod, EtsStmt>,
val dominators: (EtsMethod) -> GraphDominators<EtsStmt>,
val savedTypes: MutableMap<EtsType, MutableList<EtsTypeFact>>,
val doAddKnownTypes: Boolean = true,
) : FlowFunctions<BackwardTypeDomainFact, EtsMethod, EtsStmt> {

// private val aliasesCache: MutableMap<EtsMethod, Map<EtsStmt, Pair<AliasInfo, AliasInfo>>> = hashMapOf()
Expand Down Expand Up @@ -200,9 +201,21 @@ class BackwardFlowFunctions(
// Case `return x`
// ∅ |= x:unknown
if (current is EtsReturnStmt) {
val variable = current.returnValue?.toBase()
if (variable != null) {
result += TypedVariable(variable, EtsTypeFact.UnknownEtsTypeFact)
val returnValue = current.returnValue
if (returnValue != null) {
val variable = returnValue.toBase()
val type = if (doAddKnownTypes) {
EtsTypeFact.from(returnValue.type).let {
if (it is EtsTypeFact.AnyEtsTypeFact) {
EtsTypeFact.UnknownEtsTypeFact
} else {
it
}
}
} else {
EtsTypeFact.UnknownEtsTypeFact
}
result += TypedVariable(variable, type)
}
}

Expand All @@ -223,10 +236,23 @@ class BackwardFlowFunctions(
if (rhv.accesses.isEmpty()) {
// Case `x... := y`
// ∅ |= y:unknown
result += TypedVariable(y, EtsTypeFact.UnknownEtsTypeFact)
val type = if (doAddKnownTypes) {
EtsTypeFact.from(current.rhv.type).let { it ->
if (it is EtsTypeFact.AnyEtsTypeFact) {
EtsTypeFact.UnknownEtsTypeFact
} else {
it
}
}
} else {
EtsTypeFact.UnknownEtsTypeFact
}
result += TypedVariable(y, type)
} else {
// Case `x := y.f` OR `x := y[i]`

// TODO: handle known (real) type

check(rhv.accesses.size == 1)
when (val accessor = rhv.accesses.single()) {
// Case `x := y.f`
Expand Down Expand Up @@ -359,6 +385,11 @@ class BackwardFlowFunctions(
cls = null,
properties = mapOf(a.name to fact.type)
)
// val realType = EtsTypeFact.from(current.rhv.type)
// val type = newType.intersect(realType) ?: run {
// logger.warn { "Empty intersection of fact and real type: $newType & $realType" }
// newType
// }
result += TypedVariable(y, type).withTypeGuards(current)
// aliases: +|= z:{f:T}
// for (z in preAliases.getAliases(AccessPath(y, emptyList()))) {
Expand All @@ -373,6 +404,11 @@ class BackwardFlowFunctions(
// x:T |= x:T (keep) + y:Array<T>
val y = rhv.base
val type = EtsTypeFact.ArrayEtsTypeFact(elementType = fact.type)
// val realType = EtsTypeFact.from(current.rhv.type)
// val type = newType.intersect(realType) ?: run {
// logger.warn { "Empty intersection of fact and real type: $newType & $realType" }
// newType
// }
val newFact = TypedVariable(y, type).withTypeGuards(current)
return listOf(fact, newFact)
}
Expand All @@ -386,11 +422,11 @@ class BackwardFlowFunctions(
// Case `x.f := y`
is FieldAccessor -> {
if (fact.type is EtsTypeFact.UnionEtsTypeFact) {
TODO("Support union type for x.f := y in BW-sequent")
// TODO("Support union type for x.f := y in BW-sequent")
}

if (fact.type is EtsTypeFact.IntersectionEtsTypeFact) {
TODO("Support intersection type for x.f := y in BW-sequent")
// TODO("Support intersection type for x.f := y in BW-sequent")
}

// x:primitive |= x:primitive (pass)
Expand All @@ -412,11 +448,11 @@ class BackwardFlowFunctions(
// Case `x[i] := y`
is ElementAccessor -> {
if (fact.type is EtsTypeFact.UnionEtsTypeFact) {
TODO("Support union type for x[i] := y in BW-sequent")
// TODO("Support union type for x[i] := y in BW-sequent")
}

if (fact.type is EtsTypeFact.IntersectionEtsTypeFact) {
TODO("Support intersection type for x[i] := y in BW-sequent")
// TODO("Support intersection type for x[i] := y in BW-sequent")
}

// x:Array<T> |= x:Array<T> (pass)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import org.jacodb.ets.base.EtsUndefinedType
import org.jacodb.ets.base.EtsUnionType
import org.jacodb.ets.base.EtsUnknownType
import org.jacodb.ets.base.INSTANCE_INIT_METHOD_NAME
import org.usvm.dataflow.ts.util.Globals

private val logger = KotlinLogging.logger {}

Expand Down Expand Up @@ -44,7 +45,9 @@ sealed interface EtsTypeFact {
}
}

fun intersect(other: EtsTypeFact): EtsTypeFact? {
fun intersect(other: EtsTypeFact?): EtsTypeFact? {
if (other == null) return this

if (this == other) return this

if (other is UnknownEtsTypeFact) return this
Expand Down Expand Up @@ -151,10 +154,41 @@ sealed interface EtsTypeFact {
override fun toString(): String = "Array<$elementType>"
}

data class ObjectEtsTypeFact(
@ConsistentCopyVisibility
data class ObjectEtsTypeFact private constructor(
val cls: EtsType?,
val properties: Map<String, EtsTypeFact>,
) : BasicType {
companion object {
operator fun invoke(
cls: EtsType?,
properties: Map<String, EtsTypeFact>,
): ObjectEtsTypeFact {
if (cls is EtsUnclearRefType && cls.name == "Object") {
return ObjectEtsTypeFact(null, properties)
}
return ObjectEtsTypeFact(cls, properties)
}
}

fun getRealProperties(): Map<String, EtsTypeFact> {
val scene = Globals.scene ?: return properties
if (cls == null || cls !is EtsClassType) {
return properties
}
val clazz = scene.projectAndSdkClasses.firstOrNull { it.signature == cls.signature }
?: return properties
val props = properties.toMutableMap()
clazz.methods.forEach { m ->
props.merge(m.name, FunctionEtsTypeFact) { old, new ->
old.intersect(new).also {
if (it == null) logger.warn { "Empty intersection: $old & $new" }
}
}
}
return props
}

override fun toString(): String {
val clsName = cls?.typeName?.takeUnless { it.startsWith(ANONYMOUS_CLASS_PREFIX) } ?: "Object"
val funProps = properties.entries
Expand Down Expand Up @@ -360,23 +394,26 @@ sealed interface EtsTypeFact {
return mkIntersectionType(guardedType, other)
}

private fun tryIntersect(cls1: EtsType?, cls2: EtsType?): EtsType? {
if (cls1 == cls2) return cls1
if (cls1 == null) return cls2
if (cls2 == null) return cls1
// TODO: isSubtype
return null
}

private fun intersect(obj1: ObjectEtsTypeFact, obj2: ObjectEtsTypeFact): EtsTypeFact? {
val intersectionProperties = obj1.properties.toMutableMap()
for ((property, type) in obj2.properties) {
val intersectionProperties = obj1.getRealProperties().toMutableMap()
for ((property, type) in obj2.getRealProperties()) {
val currentType = intersectionProperties[property]
if (currentType == null) {
intersectionProperties[property] = type
continue
} else {
intersectionProperties[property] = currentType.intersect(type)
?: return null
}

intersectionProperties[property] = currentType.intersect(type) ?: return null
}

val intersectionCls = if (obj1.cls != null && obj2.cls != null) {
obj1.cls.takeIf { it == obj2.cls }
} else {
obj1.cls ?: obj2.cls
}
val intersectionCls = tryIntersect(obj1.cls, obj2.cls)
return ObjectEtsTypeFact(intersectionCls, intersectionProperties)
}

Expand All @@ -391,7 +428,7 @@ sealed interface EtsTypeFact {
type
}

return ObjectEtsTypeFact(obj.cls, intersectionProperties)
return ObjectEtsTypeFact(null, intersectionProperties)
}

private fun union(unionType: UnionEtsTypeFact, other: EtsTypeFact): EtsTypeFact {
Expand Down Expand Up @@ -504,7 +541,7 @@ sealed interface EtsTypeFact {
is EtsUnclearRefType -> ObjectEtsTypeFact(type, emptyMap())
// is EtsGenericType -> TODO()
else -> {
logger.error { "Unsupported type: $type" }
logger.warn { "Unsupported type: $type" }
UnknownEtsTypeFact
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.usvm.dataflow.ts.infer

import org.jacodb.ets.base.EtsNopStmt
import org.jacodb.ets.base.EtsStmt
import org.jacodb.ets.base.EtsType
import org.jacodb.ets.model.EtsMethod
Expand All @@ -10,7 +11,7 @@ import org.usvm.dataflow.ts.graph.EtsApplicationGraph

class ForwardAnalyzer(
val graph: EtsApplicationGraph,
methodInitialTypes: Map<EtsMethod, EtsMethodTypeFacts>,
methodInitialTypes: Map<EtsMethod, Map<AccessPathBase, EtsTypeFact>>,
typeInfo: Map<EtsType, EtsTypeFact>,
doAddKnownTypes: Boolean = true,
) : Analyzer<ForwardTypeDomainFact, AnalyzerEvent, EtsMethod, EtsStmt> {
Expand All @@ -27,18 +28,18 @@ class ForwardAnalyzer(
override fun handleNewEdge(edge: Edge<ForwardTypeDomainFact, EtsStmt>): List<AnalyzerEvent> {
val (startVertex, currentVertex) = edge
val (current, currentFact) = currentVertex

val method = graph.methodOf(current)
val currentIsExit = current in graph.exitPoints(method)

if (!currentIsExit) return emptyList()

return listOf(
ForwardSummaryAnalyzerEvent(
method = method,
initialVertex = startVertex,
exitVertex = currentVertex,
val currentIsExit = current in graph.exitPoints(method) ||
(current is EtsNopStmt && graph.successors(current).none())
if (currentIsExit) {
return listOf(
ForwardSummaryAnalyzerEvent(
method = method,
initialVertex = startVertex,
exitVertex = currentVertex,
)
)
)
}
return emptyList()
}
}
Loading

0 comments on commit 602acc7

Please sign in to comment.