Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python] First step in supporting type inference, based on the usage of sq_concat #207

Merged
merged 6 commits into from
Aug 16, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,11 @@ JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasNbPositiv
return type->tp_as_number && type->tp_as_number->nb_positive;
}

JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasSqConcat(JNIEnv *env, jobject _, jlong type_ref) {
QUERY_TYPE_HAS_PREFIX
return type->tp_as_sequence && type->tp_as_sequence->sq_concat;
}

JNIEXPORT jint JNICALL Java_org_usvm_interpreter_CPythonAdapter_typeHasSqLength(JNIEnv *env, jobject _, jlong type_ref) {
QUERY_TYPE_HAS_PREFIX
return type->tp_as_sequence && type->tp_as_sequence->sq_length;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@ val sampleStringFunction = StringProgramProvider(
/**
* Sample of a function that cannot be covered right now.
* */
val listConcatProgram = StringProgramProvider(
val tupleConcatProgram = StringProgramProvider(
"""
def list_concat(x):
y = x + [1]
if len(y[::-1]) == 5:
return 1
return 2
def tuple_concat(x, y):
z = x + y
return z + (1, 2, 3)
""".trimIndent(),
"list_concat",
) { listOf(PythonAnyType) }
"tuple_concat",
) { listOf(PythonAnyType, PythonAnyType) }
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,20 @@ class SimpleTypeInferenceTest: PythonTestRunnerForPrimitiveProgram("SimpleTypeIn
)
}

@Test
fun testListConcatUsage() {
check2WithConcreteRun(
constructFunction("list_concat_usage", List(2) { PythonAnyType }),
ignoreNumberOfAnalysisResults,
standardConcolicAndConcreteChecks,
/* invariants = */ emptyList(),
/* propertiesToDiscover = */ listOf(
{ _, _, res -> res.selfTypeName == "AssertionError" },
{ _, _, res -> res.repr == "None" }
)
)
}

@Test
fun testLenUsage() {
check1WithConcreteRun(
Expand Down
6 changes: 6 additions & 0 deletions usvm-python/src/test/resources/samples/SimpleTypeInference.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ def isinstance_sample(x):
return "Not reachable"


def list_concat_usage(x, y):
z = x + y
z += []
assert z


def len_usage(x):
if len(x) == 5:
return 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ public class CPythonAdapter {
public native int typeHasNbMatrixMultiply(long type);
public native int typeHasNbNegative(long type);
public native int typeHasNbPositive(long type);
public native int typeHasSqConcat(long type);
public native int typeHasSqLength(long type);
public native int typeHasMpLength(long type);
public native int typeHasMpSubscript(long type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ object ConcretePythonInterpreter {
val typeHasNbMatrixMultiply = createTypeQuery { pythonAdapter.typeHasNbMatrixMultiply(it) }
val typeHasNbNegative = createTypeQuery { pythonAdapter.typeHasNbNegative(it) }
val typeHasNbPositive = createTypeQuery { pythonAdapter.typeHasNbPositive(it) }
val typeHasSqConcat = createTypeQuery { pythonAdapter.typeHasSqConcat(it) }
val typeHasSqLength = createTypeQuery { pythonAdapter.typeHasSqLength(it) }
val typeHasMpLength = createTypeQuery { pythonAdapter.typeHasMpLength(it) }
val typeHasMpSubscript = createTypeQuery { pythonAdapter.typeHasMpSubscript(it) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import org.usvm.machine.types.HasNbMultiply
import org.usvm.machine.types.HasNbNegative
import org.usvm.machine.types.HasNbPositive
import org.usvm.machine.types.HasNbSubtract
import org.usvm.machine.types.HasSqConcat
import org.usvm.machine.types.HasSqLength
import org.usvm.machine.types.HasTpCall
import org.usvm.machine.types.HasTpGetattro
Expand All @@ -37,7 +38,10 @@ fun nbAddKt(
context.ctx
) {
context.curState ?: return
pyAssert(context, left.evalIsSoft(context, HasNbAdd) or right.evalIsSoft(context, HasNbAdd))
val nbAdd = left.evalIsSoft(context, HasNbAdd) or right.evalIsSoft(context, HasNbAdd)
val sqConcat = left.evalIsSoft(context, HasSqConcat) and right.evalIsSoft(context, HasSqConcat)
pyAssert(context, context.ctx.mkImplies(nbAdd.not(), sqConcat))
tochilinak marked this conversation as resolved.
Show resolved Hide resolved
pyFork(context, nbAdd)
}

fun nbSubtractKt(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ object HasNbPositive : TypeProtocol() {
ConcretePythonInterpreter.typeHasNbPositive(type.asObject)
}

object HasSqConcat : TypeProtocol() {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasSqConcat(type.asObject)
}

object HasSqLength : TypeProtocol() {
override fun acceptsConcrete(type: ConcretePythonType): Boolean =
ConcretePythonInterpreter.typeHasSqLength(type.asObject)
Expand Down
Loading