From 70d0b1078c32753631fa3933b7851c270305c84f Mon Sep 17 00:00:00 2001 From: Mathias Morbitzer Date: Thu, 19 Dec 2024 10:43:07 +0100 Subject: [PATCH] better function summaries processing --- .../aisec/cpg/passes/PointsToPass.kt | 156 +++++++++--------- .../aisec/cpg/passes/PointsToPassTest.kt | 8 +- 2 files changed, 83 insertions(+), 81 deletions(-) diff --git a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPass.kt b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPass.kt index 9d4da37aa3..b7e114f21e 100644 --- a/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPass.kt +++ b/cpg-core/src/main/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPass.kt @@ -273,40 +273,95 @@ class PointsToPass(ctx: TranslationContext) : EOGStarterPass(ctx, orderDependenc } } - // TODO: Replace this condition by collecting only invokes with function summary below. - if (currentNode.invokes.all { ctx.config.functionSummaries.hasSummary(it) }) { - // We have a FunctionSummary. Set the new values for the arguments. Push the - // values of the arguments and return value after executing the function call to our - // doubleState. - - // First, collect all writes to all parameters - val changedParams = mutableMapOf>>() - currentNode.invokes.forEach { fd -> + val destinations = identitySetOf() + val sources = identitySetOf() + val changedParams = mutableMapOf>>() + + // First, collect all writes to all parameters + currentNode.invokes + .filter { it.functionSummary.isNotEmpty() } + .forEach { fd -> + // We have a FunctionSummary. Set the new values for the arguments. Push the + // values of the arguments and return value after executing the function call to our + // doubleState. val tmp = ctx.config.functionSummaries.getLastWrites(fd) for ((k, v) in tmp) { changedParams.computeIfAbsent(k) { mutableSetOf() }.addAll(v) } } - for ((param, newValues) in changedParams) { + for ((param, newValues) in changedParams) { + when (param) { + is ParameterDeclaration -> + if (param.argumentIndex < currentNode.arguments.size) { + // Dereference the parameter + destinations.addAll( + doubleState.getValues(currentNode.arguments[param.argumentIndex]) + ) + } + is ReturnStatement -> destinations.add(currentNode) + } + newValues.forEach { (value, derefSource) -> + when (value) { + is ParameterDeclaration -> + // Add the value of the respective argument in the CallExpression + // Only dereference the parameter when we stored that in the + // functionSummary + if (value.argumentIndex < currentNode.arguments.size) { + if (derefSource) { + doubleState + .getValues(currentNode.arguments[value.argumentIndex]) + .forEach { sources.addAll(doubleState.getValues(it)) } + } else { + sources.add(currentNode.arguments[value.argumentIndex]) + } + } + is ParameterMemoryValue -> { + // In case the FunctionSummary says that we have to use the + // dereferenced value here, we look up the argument, dereference it, + // and then add it to the sources + if (value.name.localName == "derefvalue") { + val p = + currentNode.invokes + .flatMap { it.parameters } + .filter { it.name == value.name.parent } + p.forEach { + if (it.argumentIndex < currentNode.arguments.size) { + val arg = currentNode.arguments[it.argumentIndex] + sources.addAll( + doubleState.getValues(arg).flatMap { + doubleState.getValues(it) + } + ) + } + } + } + } + else -> sources.add(value) + } + } + /*// Ignore the ReturnStatements here, we use them when handling AssignExpressions + if (param !is ReturnStatement) { val destinations = when (param) { is ParameterDeclaration -> // Dereference the parameter if (param.argumentIndex < currentNode.arguments.size) { - doubleState.getValues(currentNode.arguments[param.argumentIndex]) + doubleState.getValues( + currentNode.arguments[param.argumentIndex] + ) } else null - is ReturnStatement -> identitySetOf(currentNode) else -> null } val sources = mutableSetOf() newValues.forEach { (value, derefSource) -> when (value) { is ParameterDeclaration -> - // Add the value of the respective argument in the CallExpression - // Only dereference the parameter when we stored that in the - // functionSummary if (value.argumentIndex < currentNode.arguments.size) { + // Add the value of the respective argument in the + // CallExpression + // Only dereference the parameter when we stored that in the + // functionSummary if (derefSource) { doubleState .getValues(currentNode.arguments[value.argumentIndex]) @@ -325,7 +380,7 @@ class PointsToPass(ctx: TranslationContext) : EOGStarterPass(ctx, orderDependenc .flatMap { it.parameters } .filter { it.name == value.name.parent } p.forEach { - if (it.argumentIndex < currentNode.arguments.size) { + if (value.argumentIndex < currentNode.arguments.size) { val arg = currentNode.arguments[it.argumentIndex] sources.addAll( doubleState.getValues(arg).flatMap { @@ -342,65 +397,10 @@ class PointsToPass(ctx: TranslationContext) : EOGStarterPass(ctx, orderDependenc if (destinations != null && sources.isNotEmpty()) { doubleState = doubleState.updateValues(sources, destinations) } - // Ignore the ReturnStatements here, we use them when handling AssignExpressions - if (param !is ReturnStatement) { - val destinations = - when (param) { - is ParameterDeclaration -> - // Dereference the parameter - if (param.argumentIndex < currentNode.arguments.size) { - doubleState.getValues( - currentNode.arguments[param.argumentIndex] - ) - } else null - else -> null - } - val sources = mutableSetOf() - newValues.forEach { (value, derefSource) -> - when (value) { - is ParameterDeclaration -> - if (value.argumentIndex < currentNode.arguments.size) { - // Add the value of the respective argument in the - // CallExpression - // Only dereference the parameter when we stored that in the - // functionSummary - if (derefSource) { - doubleState - .getValues(currentNode.arguments[value.argumentIndex]) - .forEach { sources.addAll(doubleState.getValues(it)) } - } else { - sources.add(currentNode.arguments[value.argumentIndex]) - } - } - is ParameterMemoryValue -> { - // In case the FunctionSummary says that we have to use the - // dereferenced value here, we look up the argument, dereference it, - // and then add it to the sources - if (value.name.localName == "derefvalue") { - val p = - currentNode.invokes - .flatMap { it.parameters } - .filter { it.name == value.name.parent } - p.forEach { - if (value.argumentIndex < currentNode.arguments.size) { - val arg = currentNode.arguments[it.argumentIndex] - sources.addAll( - doubleState.getValues(arg).flatMap { - doubleState.getValues(it) - } - ) - } - } - } - } - else -> sources.add(value) - } - } - if (destinations != null && sources.isNotEmpty()) { - doubleState = doubleState.updateValues(sources, destinations) - } - } - } + }*/ + } + if (destinations.isNotEmpty() && sources.isNotEmpty()) { + doubleState = doubleState.updateValues(sources, destinations) } return doubleState @@ -542,7 +542,7 @@ class PointsToPass(ctx: TranslationContext) : EOGStarterPass(ctx, orderDependenc } doubleState = doubleState.pushToDeclarationsState(param.memoryValue, paramDerefState) - doubleState = doubleState.push(param, paramState) + // doubleState = doubleState.push(param, paramState) } return doubleState } @@ -694,7 +694,9 @@ class PointsToPass(ctx: TranslationContext) : EOGStarterPass(ctx, orderDependenc identitySetOf(UnknownMemoryValue(node.name)) } val retVal = identitySetOf() - inputVal.forEach { retVal.addAll(this.getValues(it)) } + inputVal.forEach { + retVal.addAll(/*this.getValues(it)*/ fetchElementFromDeclarationState(it)) + } retVal } is Declaration -> { diff --git a/cpg-language-cxx/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPassTest.kt b/cpg-language-cxx/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPassTest.kt index b087c0f4d7..3fec2a5193 100644 --- a/cpg-language-cxx/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPassTest.kt +++ b/cpg-language-cxx/src/test/kotlin/de/fraunhofer/aisec/cpg/passes/PointsToPassTest.kt @@ -1107,8 +1107,8 @@ class PointsToPassTest { // Line 159 assertEquals(1, local_20Line159.prevDFG.size) - assertEquals(1, param_1Line145.prevDFG.size) - assertEquals(param_1Line145.prevDFG.first(), local_20Line159.prevDFG.first()) + assertEquals(1, param_1Line159.prevDFG.size) + assertEquals(param_1Line159.prevDFG.first(), local_20Line159.prevDFG.first()) // Effect from Line 160 assertEquals(1, local_30Line165.prevDFG.size) @@ -1129,8 +1129,8 @@ class PointsToPassTest { assertEquals(ceLine172, local_28Line172.prevDFG.firstOrNull()) // Line 177 TODO: What do we want to check here? - /* assertEquals(1, local_28Line177.prevDFG.size) - assertEquals(local_10Line172, local_28Line177.prevDFG.firstOrNull())*/ + assertEquals(1, local_28Line177.prevDFG.size) + assertEquals(local_10Line172, local_28Line177.prevDFG.firstOrNull()) // Line 179 assertEquals(2, local_28Line179.prevDFG.size)