diff --git a/usvm-jvm/build.gradle.kts b/usvm-jvm/build.gradle.kts index 0cf77dc8ac..91e4b6d4e6 100644 --- a/usvm-jvm/build.gradle.kts +++ b/usvm-jvm/build.gradle.kts @@ -10,6 +10,12 @@ val samples by sourceSets.creating { } } +val `samples-jdk11` by sourceSets.creating { + java { + srcDir("src/samples-jdk11/java") + } +} + val `sample-approximations` by sourceSets.creating { java { srcDir("src/sample-approximations/java") @@ -92,9 +98,20 @@ val `usvm-api-jar` = tasks.register("usvm-api-jar") { val testSamples by configurations.creating val testSamplesWithApproximations by configurations.creating +val compileSamplesJdk11 = tasks.register("compileSamplesJdk11") { + sourceCompatibility = JavaVersion.VERSION_11.toString() + targetCompatibility = JavaVersion.VERSION_11.toString() + + source = `samples-jdk11`.java + classpath = `samples-jdk11`.compileClasspath + options.sourcepath = `samples-jdk11`.java + destinationDirectory = `samples-jdk11`.java.destinationDirectory +} + dependencies { testSamples(samples.output) testSamples(`usvm-api`.output) + testSamples(files(`samples-jdk11`.java.destinationDirectory)) testSamplesWithApproximations(samples.output) testSamplesWithApproximations(`usvm-api`.output) @@ -104,7 +121,7 @@ dependencies { tasks.withType { dependsOn(`usvm-api-jar`) - dependsOn(testSamples, testSamplesWithApproximations) + dependsOn(compileSamplesJdk11, testSamples, testSamplesWithApproximations) val usvmApiJarPath = `usvm-api-jar`.get().outputs.files.singleFile val usvmApproximationJarPath = approximations.resolvedConfiguration.files.single() diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt index 2adba5cf61..2c1000e25e 100644 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcExprResolver.kt @@ -103,6 +103,9 @@ import org.usvm.machine.interpreter.statics.JcStaticFieldRegionId import org.usvm.machine.interpreter.statics.JcStaticFieldsMemoryRegion import org.usvm.machine.interpreter.statics.isInitialized import org.usvm.machine.interpreter.statics.markAsInitialized +import org.usvm.machine.interpreter.transformers.JcMultiDimArrayAllocationTransformer +import org.usvm.machine.interpreter.transformers.JcStringConcatTransformer +import org.usvm.machine.logger import org.usvm.machine.operator.JcBinaryOperator import org.usvm.machine.operator.JcUnaryOperator import org.usvm.machine.operator.ensureBvExpr @@ -408,7 +411,11 @@ class JcExprResolver( } override fun visitJcDynamicCallExpr(expr: JcDynamicCallExpr): UExpr? = - resolveInvoke( + apply { + if (JcStringConcatTransformer.methodIsStringConcat(expr.method.method)) { + logger.warn { "JcStringConcatTransformer should be used to process string concatenation" } + } + }.resolveInvoke( expr.method, instanceExpr = null, argumentExprs = { expr.callSiteArgs }, diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcMultiDimArrayAllocationTransformer.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcMultiDimArrayAllocationTransformer.kt deleted file mode 100644 index 52306a7db6..0000000000 --- a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/JcMultiDimArrayAllocationTransformer.kt +++ /dev/null @@ -1,246 +0,0 @@ -package org.usvm.machine.interpreter - -import org.jacodb.api.jvm.JcArrayType -import org.jacodb.api.jvm.JcClasspath -import org.jacodb.api.jvm.JcInstExtFeature -import org.jacodb.api.jvm.JcMethod -import org.jacodb.api.jvm.JcType -import org.jacodb.api.jvm.cfg.JcAddExpr -import org.jacodb.api.jvm.cfg.JcArrayAccess -import org.jacodb.api.jvm.cfg.JcAssignInst -import org.jacodb.api.jvm.cfg.JcCatchInst -import org.jacodb.api.jvm.cfg.JcExpr -import org.jacodb.api.jvm.cfg.JcExprVisitor -import org.jacodb.api.jvm.cfg.JcGeExpr -import org.jacodb.api.jvm.cfg.JcGotoInst -import org.jacodb.api.jvm.cfg.JcIfInst -import org.jacodb.api.jvm.cfg.JcInst -import org.jacodb.api.jvm.cfg.JcInstList -import org.jacodb.api.jvm.cfg.JcInstLocation -import org.jacodb.api.jvm.cfg.JcInstRef -import org.jacodb.api.jvm.cfg.JcInt -import org.jacodb.api.jvm.cfg.JcLocalVar -import org.jacodb.api.jvm.cfg.JcNewArrayExpr -import org.jacodb.api.jvm.cfg.JcValue -import org.jacodb.api.jvm.ext.boolean -import org.jacodb.api.jvm.ext.int -import org.jacodb.impl.cfg.JcInstListImpl -import org.jacodb.impl.cfg.JcInstLocationImpl -import kotlin.contracts.ExperimentalContracts -import kotlin.contracts.InvocationKind -import kotlin.contracts.contract - -object JcMultiDimArrayAllocationTransformer : JcInstExtFeature { - override fun transformInstList(method: JcMethod, list: JcInstList): JcInstList { - val multiDimArrayAllocations = list.mapNotNull { inst -> - val assignInst = inst as? JcAssignInst ?: return@mapNotNull null - val arrayAllocation = assignInst.rhv as? JcNewArrayExpr ?: return@mapNotNull null - if (arrayAllocation.dimensions.size == 1) return@mapNotNull null - assignInst to arrayAllocation - } - - if (multiDimArrayAllocations.isEmpty()) return list - - val modifiedInstructions = list.instructions.toMutableList() - val maxLocalVarIndex = modifiedInstructions.maxOfOrNull { LocalVarMaxIndexFinder.find(it.operands) } ?: -1 - - var generatedLocalVarIndex = maxLocalVarIndex + 1 - val modifiedLocationIndices = hashMapOf>() - - for ((assignInst, arrayAllocation) in multiDimArrayAllocations) { - val originalLocation = assignInst.location - val blockGenerator = ArrayInitializationBlockGenerator( - method.enclosingClass.classpath, - originalLocation, modifiedInstructions, generatedLocalVarIndex - ) - - blockGenerator.generateBlock(assignInst.lhv, arrayAllocation) - blockGenerator.generateBlockJump() - - generatedLocalVarIndex = blockGenerator.localVarIndex - val generatedLocations = blockGenerator.generatedLocations - modifiedLocationIndices[originalLocation.index] = generatedLocations.map { it.index } - } - - fixCatchBlockThrowers(modifiedInstructions, modifiedLocationIndices) - - return JcInstListImpl(modifiedInstructions) - } - - /** - * Since we generate multiple instructions instead of a single one, - * we must ensure that all catchers of the original instruction will - * catch exceptions of generated instructions. - * */ - private fun fixCatchBlockThrowers( - instructions: MutableList, - modifiedLocationIndices: Map> - ) { - for (i in instructions.indices) { - val instruction = instructions[i] - if (instruction !is JcCatchInst) continue - - val throwers = instruction.throwers.toMutableList() - for (throwerIdx in throwers.indices) { - val thrower = throwers[throwerIdx] - val generatedLocations = modifiedLocationIndices[thrower.index] ?: continue - generatedLocations.mapTo(throwers) { JcInstRef(it) } - } - - instructions[i] = with(instruction) { - JcCatchInst(location, throwable, throwableTypes, throwers) - } - } - } - - private class ArrayInitializationBlockGenerator( - private val cp: JcClasspath, - private val originalLocation: JcInstLocation, - private val instructions: MutableList, - initialLocalVarIndex: Int, - ) { - var localVarIndex: Int = initialLocalVarIndex - val generatedLocations = mutableListOf() - - fun nextLocalVar(name: String, type: JcType) = JcLocalVar(localVarIndex++, name, type) - - /** - * original: - * result = new T[d0][d1][d2] - * - * rewrited: - * a0: T[][][] = new T[d0][][] - * i0 = 0 - * INIT_0_START: - * if (i0 >= d0) goto INIT_0_END - * - * a1: T[][] = new T[d1][] - * i1 = 0 - * - * INIT_1_START: - * if (i1 >= d1) goto INIT_1_END - * - * a2: T[] = new T[d2] - * - * a1[i1] = a2 - * i1++ - * goto INIT_1_START - * - * INIT_1_END: - * a0[i0] = a1 - * i0++ - * goto INIT_0_START - * - * INIT_0_END: - * result = a0 - * */ - fun generateBlock(resultVariable: JcValue, arrayAllocation: JcNewArrayExpr) { - val type = arrayAllocation.type as? JcArrayType - ?: error("Incorrect array allocation: $arrayAllocation") - - val arrayVar = generateBlock(type, arrayAllocation.dimensions, dimensionIdx = 0) - addInstruction { loc -> - JcAssignInst(loc, resultVariable, arrayVar) - } - } - - private fun generateBlock(type: JcArrayType, dimensions: List, dimensionIdx: Int): JcValue { - val dimension = dimensions[dimensionIdx] - val arrayVar = nextLocalVar("a_${originalLocation.index}_$dimensionIdx", type) - - addInstruction { loc -> - JcAssignInst(loc, arrayVar, JcNewArrayExpr(type, listOf(dimension))) - } - - if (dimensionIdx == dimensions.lastIndex) return arrayVar - - val initializerIdxVar = nextLocalVar("i_${originalLocation.index}_$dimensionIdx", cp.int) - addInstruction { loc -> - JcAssignInst(loc, initializerIdxVar, JcInt(0, cp.int)) - } - - val initStartLoc: JcInstLocation - addInstruction { loc -> - initStartLoc = loc - - val cond = JcGeExpr(cp.boolean, initializerIdxVar, dimension) - val nextInst = JcInstRef(loc.index + 1) - JcIfInst(loc, cond, END_LABEL_STUB, nextInst) - } - - val nestedArrayType = type.elementType as? JcArrayType - ?: error("Incorrect array type: $type") - - val nestedArrayVar = generateBlock(nestedArrayType, dimensions, dimensionIdx + 1) - - addInstruction { loc -> - val arrayElement = JcArrayAccess(arrayVar, initializerIdxVar, nestedArrayType) - JcAssignInst(loc, arrayElement, nestedArrayVar) - } - - addInstruction { loc -> - JcAssignInst(loc, initializerIdxVar, JcAddExpr(cp.int, initializerIdxVar, JcInt(1, cp.int))) - } - - val initEndLoc: JcInstLocation - addInstruction { loc -> - initEndLoc = loc - JcGotoInst(loc, JcInstRef(initStartLoc.index)) - } - - val blockStartInst = instructions[initStartLoc.index] as JcIfInst - val blockEnd = JcInstRef(initEndLoc.index + 1) - instructions[initStartLoc.index] = replaceEndLabelStub(blockStartInst, blockEnd) - - return arrayVar - } - - fun generateBlockJump() { - addInstruction { loc -> - JcGotoInst(loc, JcInstRef(originalLocation.index + 1)) - } - - val arrayInitializationBlockStart = JcInstRef(generatedLocations.first().index) - instructions[originalLocation.index] = JcGotoInst(originalLocation, arrayInitializationBlockStart) - } - - @OptIn(ExperimentalContracts::class) - private inline fun addInstruction(body: (JcInstLocation) -> JcInst) { - contract { - callsInPlace(body, InvocationKind.EXACTLY_ONCE) - } - - instructions.addInstruction(originalLocation) { loc -> - generatedLocations += loc - body(loc) - } - } - - companion object { - private val END_LABEL_STUB = JcInstRef(-1) - - private fun replaceEndLabelStub(inst: JcIfInst, replacement: JcInstRef): JcIfInst = with(inst) { - JcIfInst( - location, - condition, - if (trueBranch == END_LABEL_STUB) replacement else trueBranch, - if (falseBranch == END_LABEL_STUB) replacement else falseBranch, - ) - } - } - } - - private inline fun MutableList.addInstruction(origin: JcInstLocation, body: (JcInstLocation) -> JcInst) { - val index = size - val newLocation = JcInstLocationImpl(origin.method, index, origin.lineNumber) - val instruction = body(newLocation) - check(size == index) - add(instruction) - } - - private object LocalVarMaxIndexFinder : JcExprVisitor.Default { - override fun defaultVisitJcExpr(expr: JcExpr) = find(expr.operands) - override fun visitJcLocalVar(value: JcLocalVar) = value.index - fun find(expressions: Iterable): Int = expressions.maxOfOrNull { it.accept(this) } ?: -1 - } -} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcMultiDimArrayAllocationTransformer.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcMultiDimArrayAllocationTransformer.kt new file mode 100644 index 0000000000..c4ea739d79 --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcMultiDimArrayAllocationTransformer.kt @@ -0,0 +1,159 @@ +package org.usvm.machine.interpreter.transformers + +import org.jacodb.api.jvm.JcArrayType +import org.jacodb.api.jvm.JcClasspath +import org.jacodb.api.jvm.JcInstExtFeature +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.cfg.JcAddExpr +import org.jacodb.api.jvm.cfg.JcArrayAccess +import org.jacodb.api.jvm.cfg.JcAssignInst +import org.jacodb.api.jvm.cfg.JcGeExpr +import org.jacodb.api.jvm.cfg.JcGotoInst +import org.jacodb.api.jvm.cfg.JcIfInst +import org.jacodb.api.jvm.cfg.JcInst +import org.jacodb.api.jvm.cfg.JcInstList +import org.jacodb.api.jvm.cfg.JcInstLocation +import org.jacodb.api.jvm.cfg.JcInstRef +import org.jacodb.api.jvm.cfg.JcInt +import org.jacodb.api.jvm.cfg.JcNewArrayExpr +import org.jacodb.api.jvm.cfg.JcValue +import org.jacodb.api.jvm.ext.boolean +import org.jacodb.api.jvm.ext.int +import org.usvm.machine.interpreter.transformers.JcSingleInstructionTransformer.BlockGenerationContext + +object JcMultiDimArrayAllocationTransformer : JcInstExtFeature { + override fun transformInstList(method: JcMethod, list: JcInstList): JcInstList { + val multiDimArrayAllocations = list.mapNotNull { inst -> + val assignInst = inst as? JcAssignInst ?: return@mapNotNull null + val arrayAllocation = assignInst.rhv as? JcNewArrayExpr ?: return@mapNotNull null + if (arrayAllocation.dimensions.size == 1) return@mapNotNull null + assignInst to arrayAllocation + } + + if (multiDimArrayAllocations.isEmpty()) return list + + val transformer = JcSingleInstructionTransformer(list) + for ((assignInst, arrayAllocation) in multiDimArrayAllocations) { + transformer.generateReplacementBlock(assignInst) { + generateBlock( + method.enclosingClass.classpath, + assignInst.lhv, arrayAllocation + ) + } + } + + return transformer.buildInstList() + } + + /** + * original: + * result = new T[d0][d1][d2] + * + * rewrited: + * a0: T[][][] = new T[d0][][] + * i0 = 0 + * INIT_0_START: + * if (i0 >= d0) goto INIT_0_END + * + * a1: T[][] = new T[d1][] + * i1 = 0 + * + * INIT_1_START: + * if (i1 >= d1) goto INIT_1_END + * + * a2: T[] = new T[d2] + * + * a1[i1] = a2 + * i1++ + * goto INIT_1_START + * + * INIT_1_END: + * a0[i0] = a1 + * i0++ + * goto INIT_0_START + * + * INIT_0_END: + * result = a0 + * */ + private fun BlockGenerationContext.generateBlock( + cp: JcClasspath, + resultVariable: JcValue, + arrayAllocation: JcNewArrayExpr + ) { + val type = arrayAllocation.type as? JcArrayType + ?: error("Incorrect array allocation: $arrayAllocation") + + val arrayVar = generateBlock(cp, type, arrayAllocation.dimensions, dimensionIdx = 0) + addInstruction { loc -> + JcAssignInst(loc, resultVariable, arrayVar) + } + } + + private fun BlockGenerationContext.generateBlock( + cp: JcClasspath, + type: JcArrayType, + dimensions: List, + dimensionIdx: Int + ): JcValue { + val dimension = dimensions[dimensionIdx] + val arrayVar = nextLocalVar("a_${originalLocation.index}_$dimensionIdx", type) + + addInstruction { loc -> + JcAssignInst(loc, arrayVar, JcNewArrayExpr(type, listOf(dimension))) + } + + if (dimensionIdx == dimensions.lastIndex) return arrayVar + + val initializerIdxVar = nextLocalVar("i_${originalLocation.index}_$dimensionIdx", cp.int) + addInstruction { loc -> + JcAssignInst(loc, initializerIdxVar, JcInt(0, cp.int)) + } + + val initStartLoc: JcInstLocation + addInstruction { loc -> + initStartLoc = loc + + val cond = JcGeExpr(cp.boolean, initializerIdxVar, dimension) + val nextInst = JcInstRef(loc.index + 1) + JcIfInst(loc, cond, END_LABEL_STUB, nextInst) + } + + val nestedArrayType = type.elementType as? JcArrayType + ?: error("Incorrect array type: $type") + + val nestedArrayVar = generateBlock(cp, nestedArrayType, dimensions, dimensionIdx + 1) + + addInstruction { loc -> + val arrayElement = JcArrayAccess(arrayVar, initializerIdxVar, nestedArrayType) + JcAssignInst(loc, arrayElement, nestedArrayVar) + } + + addInstruction { loc -> + JcAssignInst(loc, initializerIdxVar, JcAddExpr(cp.int, initializerIdxVar, JcInt(1, cp.int))) + } + + val initEndLoc: JcInstLocation + addInstruction { loc -> + initEndLoc = loc + JcGotoInst(loc, JcInstRef(initStartLoc.index)) + } + + replaceInstructionAtLocation(initStartLoc) { blockStartInst -> + val blockEnd = JcInstRef(initEndLoc.index + 1) + replaceEndLabelStub(blockStartInst as JcIfInst, blockEnd) + } + + return arrayVar + } + + private val END_LABEL_STUB = JcInstRef(-1) + + private fun replaceEndLabelStub(inst: JcIfInst, replacement: JcInstRef): JcIfInst = with(inst) { + JcIfInst( + location, + condition, + if (trueBranch == END_LABEL_STUB) replacement else trueBranch, + if (falseBranch == END_LABEL_STUB) replacement else falseBranch, + ) + } +} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcSingleInstructionTransformer.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcSingleInstructionTransformer.kt new file mode 100644 index 0000000000..b788672e4e --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcSingleInstructionTransformer.kt @@ -0,0 +1,115 @@ +package org.usvm.machine.interpreter.transformers + +import org.jacodb.api.jvm.JcType +import org.jacodb.api.jvm.cfg.JcCatchInst +import org.jacodb.api.jvm.cfg.JcExpr +import org.jacodb.api.jvm.cfg.JcExprVisitor +import org.jacodb.api.jvm.cfg.JcGotoInst +import org.jacodb.api.jvm.cfg.JcInst +import org.jacodb.api.jvm.cfg.JcInstList +import org.jacodb.api.jvm.cfg.JcInstLocation +import org.jacodb.api.jvm.cfg.JcInstRef +import org.jacodb.api.jvm.cfg.JcLocalVar +import org.jacodb.impl.cfg.JcInstListImpl +import org.jacodb.impl.cfg.JcInstLocationImpl +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.InvocationKind +import kotlin.contracts.contract + +class JcSingleInstructionTransformer(originalInstructions: JcInstList) { + val mutableInstructions = originalInstructions.instructions.toMutableList() + private val maxLocalVarIndex = mutableInstructions.maxOfOrNull { LocalVarMaxIndexFinder.find(it.operands) } ?: -1 + + var generatedLocalVarIndex = maxLocalVarIndex + 1 + val modifiedLocationIndices = hashMapOf>() + + inline fun generateReplacementBlock(original: JcInst, blockGen: BlockGenerationContext.() -> Unit) { + val originalLocation = original.location + val ctx = BlockGenerationContext(originalLocation, generatedLocalVarIndex) + + ctx.blockGen() + + // add back jump from generated block + ctx.addInstruction { loc -> + JcGotoInst(loc, JcInstRef(originalLocation.index + 1)) + } + + // replace original instruction with jump to the generated block + val replacementBlockStart = JcInstRef(ctx.generatedLocations.first().index) + mutableInstructions[originalLocation.index] = JcGotoInst(originalLocation, replacementBlockStart) + + generatedLocalVarIndex = ctx.localVarIndex + + val generatedLocations = ctx.generatedLocations + modifiedLocationIndices[originalLocation.index] = generatedLocations.map { it.index } + } + + fun buildInstList(): JcInstList { + fixCatchBlockThrowers() + return JcInstListImpl(mutableInstructions) + } + + /** + * Since we generate multiple instructions instead of a single one, + * we must ensure that all catchers of the original instruction will + * catch exceptions of generated instructions. + * */ + private fun fixCatchBlockThrowers() { + for (i in mutableInstructions.indices) { + val instruction = mutableInstructions[i] + if (instruction !is JcCatchInst) continue + + val throwers = instruction.throwers.toMutableList() + for (throwerIdx in throwers.indices) { + val thrower = throwers[throwerIdx] + val generatedLocations = modifiedLocationIndices[thrower.index] ?: continue + generatedLocations.mapTo(throwers) { JcInstRef(it) } + } + + mutableInstructions[i] = with(instruction) { + JcCatchInst(location, throwable, throwableTypes, throwers) + } + } + } + + inner class BlockGenerationContext( + val originalLocation: JcInstLocation, + initialLocalVarIndex: Int, + ) { + var localVarIndex: Int = initialLocalVarIndex + val generatedLocations = mutableListOf() + + fun nextLocalVar(name: String, type: JcType) = JcLocalVar(localVarIndex++, name, type) + + @OptIn(ExperimentalContracts::class) + inline fun addInstruction(body: (JcInstLocation) -> JcInst) { + contract { + callsInPlace(body, InvocationKind.EXACTLY_ONCE) + } + + mutableInstructions.addInstruction(originalLocation) { loc -> + generatedLocations += loc + body(loc) + } + } + + fun replaceInstructionAtLocation(loc: JcInstLocation, replacement: (JcInst) -> JcInst) { + val currentInst = mutableInstructions[loc.index] + mutableInstructions[loc.index] = replacement(currentInst) + } + } + + inline fun MutableList.addInstruction(origin: JcInstLocation, body: (JcInstLocation) -> JcInst) { + val index = size + val newLocation = JcInstLocationImpl(origin.method, index, origin.lineNumber) + val instruction = body(newLocation) + check(size == index) + add(instruction) + } + + private object LocalVarMaxIndexFinder : JcExprVisitor.Default { + override fun defaultVisitJcExpr(expr: JcExpr) = find(expr.operands) + override fun visitJcLocalVar(value: JcLocalVar) = value.index + fun find(expressions: Iterable): Int = expressions.maxOfOrNull { it.accept(this) } ?: -1 + } +} diff --git a/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcStringConcatTransformer.kt b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcStringConcatTransformer.kt new file mode 100644 index 0000000000..658cb24acb --- /dev/null +++ b/usvm-jvm/src/main/kotlin/org/usvm/machine/interpreter/transformers/JcStringConcatTransformer.kt @@ -0,0 +1,242 @@ +package org.usvm.machine.interpreter.transformers + +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcInstExtFeature +import org.jacodb.api.jvm.JcMethod +import org.jacodb.api.jvm.JcPrimitiveType +import org.jacodb.api.jvm.JcRefType +import org.jacodb.api.jvm.JcType +import org.jacodb.api.jvm.JcTypedMethod +import org.jacodb.api.jvm.cfg.BsmArg +import org.jacodb.api.jvm.cfg.BsmDoubleArg +import org.jacodb.api.jvm.cfg.BsmFloatArg +import org.jacodb.api.jvm.cfg.BsmHandle +import org.jacodb.api.jvm.cfg.BsmIntArg +import org.jacodb.api.jvm.cfg.BsmLongArg +import org.jacodb.api.jvm.cfg.BsmMethodTypeArg +import org.jacodb.api.jvm.cfg.BsmStringArg +import org.jacodb.api.jvm.cfg.BsmTypeArg +import org.jacodb.api.jvm.cfg.JcAssignInst +import org.jacodb.api.jvm.cfg.JcDynamicCallExpr +import org.jacodb.api.jvm.cfg.JcInst +import org.jacodb.api.jvm.cfg.JcInstList +import org.jacodb.api.jvm.cfg.JcStaticCallExpr +import org.jacodb.api.jvm.cfg.JcStringConstant +import org.jacodb.api.jvm.cfg.JcValue +import org.jacodb.api.jvm.cfg.JcVirtualCallExpr +import org.jacodb.api.jvm.ext.objectType +import org.jacodb.impl.cfg.TypedStaticMethodRefImpl +import org.jacodb.impl.cfg.VirtualMethodRefImpl +import org.usvm.machine.interpreter.transformers.JcSingleInstructionTransformer.BlockGenerationContext + +object JcStringConcatTransformer : JcInstExtFeature { + private const val JAVA_STRING = "java.lang.String" + private const val STRING_CONCAT_FACTORY = "java.lang.invoke.StringConcatFactory" + private const val STRING_CONCAT_WITH_CONSTANTS = "makeConcatWithConstants" + + fun methodIsStringConcat(method: JcMethod): Boolean = + STRING_CONCAT_WITH_CONSTANTS == method.name && STRING_CONCAT_FACTORY == method.enclosingClass.name + + override fun transformInstList(method: JcMethod, list: JcInstList): JcInstList { + val stringConcatCalls = list.mapNotNull { inst -> + val assignInst = inst as? JcAssignInst ?: return@mapNotNull null + val invokeDynamicExpr = assignInst.rhv as? JcDynamicCallExpr ?: return@mapNotNull null + if (!methodIsStringConcat(invokeDynamicExpr.method.method)) return@mapNotNull null + assignInst to invokeDynamicExpr + } + + if (stringConcatCalls.isEmpty()) return list + + val stringType = method.enclosingClass.classpath + .findTypeOrNull(JAVA_STRING) as? JcClassType + ?: return list + + val stringConcatMethod = stringType.declaredMethods.singleOrNull { + !it.isStatic && it.name == "concat" && it.parameters.size == 1 + } ?: return list + + val stringConcatElements = stringConcatCalls.mapNotNull { (assign, expr) -> + val recipe = (expr.bsmArgs.lastOrNull() as? BsmStringArg)?.value ?: return@mapNotNull null + val elements = parseStringConcatRecipe( + stringType, recipe, expr.bsmArgs.dropLast(1).asReversed(), + expr.callSiteArgs, expr.callSiteArgTypes + ) ?: return@mapNotNull null + + assign to elements + } + + if (stringConcatElements.isEmpty()) return list + + val transformer = JcSingleInstructionTransformer(list) + for ((assignment, concatElements) in stringConcatElements) { + transformer.generateReplacementBlock(assignment) { + generateConcatBlock(stringType, stringConcatMethod, assignment.lhv, concatElements) + } + } + + return transformer.buildInstList() + } + + private fun BlockGenerationContext.generateConcatBlock( + stringType: JcClassType, + stringConcatMethod: JcTypedMethod, + resultVariable: JcValue, + elements: List + ) { + if (elements.isEmpty()) { + addInstruction { loc -> + JcAssignInst(loc, resultVariable, JcStringConstant("", stringType)) + } + return + } + + val elementsIter = elements.iterator() + var current = elementStringValue(stringType, elementsIter.next()) + while (elementsIter.hasNext()) { + val element = elementStringValue(stringType, elementsIter.next()) + current = generateStringConcat(stringType, stringConcatMethod, current, element) + } + + addInstruction { loc -> + JcAssignInst(loc, resultVariable, current) + } + } + + private fun BlockGenerationContext.elementStringValue( + stringType: JcClassType, + element: StringConcatElement + ): JcValue = when (element) { + is StringConcatElement.StringElement -> element.value + is StringConcatElement.OtherElement -> { + val value = nextLocalVar("str_val", stringType) + val methodRef = element.toStringTransformer.staticMethodRef() + val callExpr = JcStaticCallExpr(methodRef, listOf(element.value)) + addInstruction { loc -> + JcAssignInst(loc, value, callExpr) + } + value + } + } + + private fun BlockGenerationContext.generateStringConcat( + stringType: JcClassType, + stringConcatMethod: JcTypedMethod, + first: JcValue, + second: JcValue + ): JcValue { + val value = nextLocalVar("str", stringType) + val methodRef = stringConcatMethod.virtualMethodRef(stringType) + val callExpr = JcVirtualCallExpr(methodRef, first, listOf(second)) + addInstruction { loc -> + JcAssignInst(loc, value, callExpr) + } + return value + } + + private fun JcTypedMethod.virtualMethodRef(stringType: JcClassType) = + VirtualMethodRefImpl.of(stringType, this) + + private fun JcTypedMethod.staticMethodRef() = TypedStaticMethodRefImpl( + enclosingType as JcClassType, + name, + method.parameters.map { it.type }, + method.returnType + ) + + private sealed interface StringConcatElement { + data class StringElement(val value: JcValue) : StringConcatElement + data class OtherElement(val value: JcValue, val toStringTransformer: JcTypedMethod) : StringConcatElement + } + + private fun parseStringConcatRecipe( + stringType: JcClassType, + recipe: String, + bsmArgs: List, + callArgs: List, + callArgTypes: List + ): List? { + val elements = mutableListOf() + + val acc = StringBuilder() + + var constCount = 0 + var argsCount = 0 + + for (recipeCh in recipe) { + when (recipeCh) { + '\u0002' -> { + // Accumulate constant args along with any constants encoded + // into the recipe + val constant = bsmArgs.getOrNull(constCount++) ?: return null + + val constantValue = when (constant) { + is BsmDoubleArg -> constant.value.toString() + is BsmFloatArg -> constant.value.toString() + is BsmIntArg -> constant.value.toString() + is BsmLongArg -> constant.value.toString() + is BsmStringArg -> constant.value + is BsmHandle, + is BsmMethodTypeArg, + is BsmTypeArg -> return null + } + + acc.append(constantValue) + } + + '\u0001' -> { + // Flush any accumulated characters into a constant + if (acc.isNotEmpty()) { + elements += StringConcatElement.StringElement( + JcStringConstant(acc.toString(), stringType) + ) + acc.setLength(0) + } + + val argValue = callArgs.getOrNull(argsCount) ?: return null + val valueType = callArgTypes.getOrNull(argsCount) ?: return null + argsCount++ + + val argElement = valueStringElement(stringType, argValue, valueType) ?: return null + elements.add(argElement) + } + + else -> { + // Not a special character, this is a constant embedded into + // the recipe itself. + acc.append(recipeCh) + } + } + } + + // Flush the remaining characters as constant: + if (acc.isNotEmpty()) { + elements += StringConcatElement.StringElement( + JcStringConstant(acc.toString(), stringType) + ) + } + + return elements + } + + private fun valueStringElement(stringType: JcClassType, value: JcValue, valueType: JcType): StringConcatElement? = + when (valueType) { + is JcPrimitiveType -> { + val valueOfMethod = stringType.findValueOfMethod(valueType) + valueOfMethod?.let { StringConcatElement.OtherElement(value, it) } + } + + stringType -> StringConcatElement.StringElement(value) + + is JcRefType -> { + val valueOfMethod = stringType.findValueOfMethod(stringType.classpath.objectType) + valueOfMethod?.let { StringConcatElement.OtherElement(value, it) } + } + + else -> null + } + + private fun JcClassType.findValueOfMethod(argumentType: JcType): JcTypedMethod? = + declaredMethods.singleOrNull { + it.isStatic && it.name == "valueOf" && it.parameters.size == 1 && it.parameters.first().type == argumentType + } +} diff --git a/usvm-jvm/src/samples-jdk11/java/org/usvm/samples/strings/StringConcatSamples.java b/usvm-jvm/src/samples-jdk11/java/org/usvm/samples/strings/StringConcatSamples.java new file mode 100644 index 0000000000..a49791f59a --- /dev/null +++ b/usvm-jvm/src/samples-jdk11/java/org/usvm/samples/strings/StringConcatSamples.java @@ -0,0 +1,26 @@ +package org.usvm.samples.strings; + +public class StringConcatSamples { + public static class Bar { + @Override + public String toString() { + return "Bar"; + } + } + + public boolean stringConcatEq() { + Bar bar = new Bar(); + int intValue = 17; + boolean boolValue = false; + String concatenated = "prefix_" + intValue + "_" + boolValue + "_" + bar + "_suffix"; + String expected = "prefix_17_false_Bar_suffix"; + return expected.equals(concatenated); + } + + public boolean stringConcatStrangeEq() { + int iv = 0; + String expected = "\u0000" + 0 + "#" + 0 + "\u0001" + 0 + "!\u0002" + 0 + "@\u0012\t"; + String concatenated = "\u0000" + iv + "#" + iv + "\u0001" + iv + "!\u0002" + iv + "@\u0012\t"; + return expected.equals(concatenated); + } +} diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt index abc6c3644f..55b5852cf5 100644 --- a/usvm-jvm/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/JacoDBContainer.kt @@ -7,7 +7,8 @@ import org.jacodb.approximation.Approximations import org.jacodb.impl.JcSettings import org.jacodb.impl.features.InMemoryHierarchy import org.jacodb.impl.jacodb -import org.usvm.machine.interpreter.JcMultiDimArrayAllocationTransformer +import org.usvm.machine.interpreter.transformers.JcMultiDimArrayAllocationTransformer +import org.usvm.machine.interpreter.transformers.JcStringConcatTransformer import org.usvm.util.classpathWithApproximations import java.io.File @@ -31,10 +32,15 @@ class JacoDBContainer( loadByteCode(classpath) } + val features = listOf( + JcMultiDimArrayAllocationTransformer, + JcStringConcatTransformer, + ) + val cp = if (samplesWithApproximationsKey == key) { - db.classpathWithApproximations(classpath, listOf(JcMultiDimArrayAllocationTransformer)) + db.classpathWithApproximations(classpath, features) } else { - db.classpath(classpath, listOf(JcMultiDimArrayAllocationTransformer)) + db.classpath(classpath, features) } db to cp } diff --git a/usvm-jvm/src/test/kotlin/org/usvm/samples/strings/StringConcatSamplesTest.kt b/usvm-jvm/src/test/kotlin/org/usvm/samples/strings/StringConcatSamplesTest.kt new file mode 100644 index 0000000000..1b3dc884f6 --- /dev/null +++ b/usvm-jvm/src/test/kotlin/org/usvm/samples/strings/StringConcatSamplesTest.kt @@ -0,0 +1,54 @@ +package org.usvm.samples.strings + +import org.jacodb.api.jvm.JcClassType +import org.jacodb.api.jvm.JcTypedMethod +import org.usvm.api.JcTest +import org.usvm.api.util.JcTestInterpreter +import org.usvm.machine.JcMachine +import org.usvm.samples.JavaMethodTestRunner +import org.usvm.util.JcTestExecutor +import org.usvm.util.JcTestResolverType +import kotlin.test.Test +import kotlin.test.assertEquals +import kotlin.test.assertTrue + +class StringConcatSamplesTest : JavaMethodTestRunner() { + @Test + fun testStringConcatEq() { + val method = findMethod("stringConcatEq") + val states = executeMethod(method) + val successStates = states.filter { it.result.isSuccess } + assertEquals(1, successStates.size) + assertTrue(successStates.single().result.getOrThrow() as Boolean) + } + + @Test + fun testStringConcatStrangeEq() { + val method = findMethod("stringConcatStrangeEq") + val states = executeMethod(method) + val successStates = states.filter { it.result.isSuccess } + assertEquals(1, successStates.size) + assertTrue(successStates.single().result.getOrThrow() as Boolean) + } + + private fun executeMethod(method: JcTypedMethod): List { + val testResolver = when (resolverType) { + JcTestResolverType.INTERPRETER -> JcTestInterpreter() + JcTestResolverType.CONCRETE_EXECUTOR -> JcTestExecutor(classpath = cp) + } + + return JcMachine(cp, options).use { machine -> + val states = machine.analyze(method.method) + states.map { testResolver.resolve(method, it) } + } + } + + private fun findMethod(methodName: String): JcTypedMethod = + (cp.findTypeOrNull(SAMPLES_CLASS) as? JcClassType) + ?.declaredMethods?.singleOrNull { it.name == methodName } + ?: error("Cannot find method $methodName in $SAMPLES_CLASS") + + companion object { + const val SAMPLES_CLASS = "org.usvm.samples.strings.StringConcatSamples" + } +}