diff --git a/core-codemods/src/test/java/io/codemodder/codemods/MoveSwitchDefaultCaseLastCodemodTest.java b/core-codemods/src/test/java/io/codemodder/codemods/MoveSwitchDefaultCaseLastCodemodTest.java index a10566269..f99aca765 100644 --- a/core-codemods/src/test/java/io/codemodder/codemods/MoveSwitchDefaultCaseLastCodemodTest.java +++ b/core-codemods/src/test/java/io/codemodder/codemods/MoveSwitchDefaultCaseLastCodemodTest.java @@ -57,9 +57,9 @@ void foo() { break; default: break; - case 0: - break; - } + case 0: + break; + } } } """; diff --git a/core-codemods/src/test/java/io/codemodder/codemods/SQLParameterizerCodemodTest.java b/core-codemods/src/test/java/io/codemodder/codemods/SQLParameterizerCodemodTest.java index 9b0a69455..88a5c5406 100644 --- a/core-codemods/src/test/java/io/codemodder/codemods/SQLParameterizerCodemodTest.java +++ b/core-codemods/src/test/java/io/codemodder/codemods/SQLParameterizerCodemodTest.java @@ -2,9 +2,21 @@ import io.codemodder.testutils.CodemodTestMixin; import io.codemodder.testutils.Metadata; +import org.junit.jupiter.api.Nested; -@Metadata( - codemodType = SQLParameterizerCodemod.class, - testResourceDir = "sql-parameterizer", - dependencies = {}) -final class SQLParameterizerCodemodTest implements CodemodTestMixin {} +final class SQLParameterizerCodemodTest { + + @Nested + @Metadata( + codemodType = SQLParameterizerCodemod.class, + testResourceDir = "sql-parameterizer/defaultTransformation", + dependencies = {}) + class DefaultTransformationTest implements CodemodTestMixin {} + + @Nested + @Metadata( + codemodType = SQLParameterizerCodemod.class, + testResourceDir = "sql-parameterizer/hijackTransformation", + dependencies = {}) + class HijackTransformationTest implements CodemodTestMixin {} +} diff --git a/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionChallenge/SqlInjectionChallenge.java.after b/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionChallenge/SqlInjectionChallenge.java.after index eb7086476..f74fbc82d 100644 --- a/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionChallenge/SqlInjectionChallenge.java.after +++ b/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionChallenge/SqlInjectionChallenge.java.after @@ -68,6 +68,7 @@ public class SqlInjectionChallenge extends AssignmentEndpoint { "select userid from sql_challenge_users where userid = ?"; PreparedStatement statement = connection.prepareStatement(checkUserQuery); statement.setString(1, username_reg); + ResultSet resultSet = statement.execute(); if (resultSet.next()) { if (username_reg.contains("tom'")) { @@ -84,7 +85,6 @@ public class SqlInjectionChallenge extends AssignmentEndpoint { preparedStatement.execute(); attackResult = success(this).feedback("user.created").feedbackArgs(username_reg).build(); } - } catch (SQLException e) { attackResult = failed(this).output("Something went wrong").build(); } diff --git a/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionLesson8/SqlInjectionLesson8.java.after b/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionLesson8/SqlInjectionLesson8.java.after index d6c77cb58..efde86db4 100644 --- a/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionLesson8/SqlInjectionLesson8.java.after +++ b/core-codemods/src/test/resources/defectdojo-sql-injection/SqlInjectionLesson8/SqlInjectionLesson8.java.after @@ -70,9 +70,10 @@ public class SqlInjectionLesson8 extends AssignmentEndpoint { try { PreparedStatement statement = connection.prepareStatement( -query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATABLE); + query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATABLE); log(connection, query); statement.setString(1, name); + statement.setString(2, auth_tan); ResultSet results = statement.execute(); if (results.getStatement() != null) { @@ -98,7 +99,6 @@ query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDAT } else { return failed(this).build(); } - } catch (SQLException e) { return failed(this) .output("
" + e.getMessage() + "") @@ -156,7 +156,7 @@ query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDAT statement.setString(1, sdf.format(cal.getTime())); statement.setString(2, action); statement.execute(); - } catch (SQLException e) { + } catch (SQLException e) { System.err.println(e.getMessage()); } } diff --git a/core-codemods/src/test/resources/jexl-expression-injection/Test.java.after b/core-codemods/src/test/resources/jexl-expression-injection/Test.java.after index f4d80f6bb..9f0543d61 100644 --- a/core-codemods/src/test/resources/jexl-expression-injection/Test.java.after +++ b/core-codemods/src/test/resources/jexl-expression-injection/Test.java.after @@ -27,7 +27,6 @@ public final class Test { JexlExpression expression = jexl.createExpression(input); JexlContext context = new MapContext(); expression.evaluate(context); - } } @@ -42,7 +41,6 @@ public final class Test { sandbox.block(cls); } new JexlBuilder().sandbox(sandbox).create().createExpression(input).evaluate(context); - } } diff --git a/core-codemods/src/test/resources/move-switch-default-last/Test.java.after b/core-codemods/src/test/resources/move-switch-default-last/Test.java.after index 6dd475c2d..75c89e56f 100644 --- a/core-codemods/src/test/resources/move-switch-default-last/Test.java.after +++ b/core-codemods/src/test/resources/move-switch-default-last/Test.java.after @@ -11,8 +11,8 @@ final class Test { case "bar": System.out.println("bar"); break; -default: - System.out.println("default"); } + default: + System.out.println("default");} System.out.println("bar"); } diff --git a/core-codemods/src/test/resources/semgrep-sql-injection-formatted-sql-string/SqlInjectionLesson5a.java.after b/core-codemods/src/test/resources/semgrep-sql-injection-formatted-sql-string/SqlInjectionLesson5a.java.after index a83b4902a..781e920e4 100644 --- a/core-codemods/src/test/resources/semgrep-sql-injection-formatted-sql-string/SqlInjectionLesson5a.java.after +++ b/core-codemods/src/test/resources/semgrep-sql-injection-formatted-sql-string/SqlInjectionLesson5a.java.after @@ -64,8 +64,9 @@ public class SqlInjectionLesson5a extends AssignmentEndpoint { "SELECT * FROM user_data WHERE first_name = 'John' and last_name = ?"; try (PreparedStatement statement = connection.prepareStatement( -query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATABLE)) { + query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATABLE)) { statement.setString(1, accountName); + ResultSet results = statement.execute(); if ((results != null) && (results.first())) { ResultSetMetaData resultsMetaData = results.getMetaData(); @@ -90,7 +91,6 @@ query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATAB .output("Your query was: " + query) .build(); } - } catch (SQLException sqle) { return failed(this).output(sqle.getMessage() + "
Your query was: " + query).build(); } @@ -135,4 +135,4 @@ query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATAB t.append("

"); return (t.toString()); } -} \ No newline at end of file +} diff --git a/core-codemods/src/test/resources/semgrep-sql-injection/SqlInjectionLesson8.java.after b/core-codemods/src/test/resources/semgrep-sql-injection/SqlInjectionLesson8.java.after index d6c77cb58..efde86db4 100644 --- a/core-codemods/src/test/resources/semgrep-sql-injection/SqlInjectionLesson8.java.after +++ b/core-codemods/src/test/resources/semgrep-sql-injection/SqlInjectionLesson8.java.after @@ -70,9 +70,10 @@ public class SqlInjectionLesson8 extends AssignmentEndpoint { try { PreparedStatement statement = connection.prepareStatement( -query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATABLE); + query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDATABLE); log(connection, query); statement.setString(1, name); + statement.setString(2, auth_tan); ResultSet results = statement.execute(); if (results.getStatement() != null) { @@ -98,7 +99,6 @@ query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDAT } else { return failed(this).build(); } - } catch (SQLException e) { return failed(this) .output("
" + e.getMessage() + "") @@ -156,7 +156,7 @@ query, ResultSet.TYPE_SCROLL_INSENSITIVE, ResultSet.CONCUR_UPDAT statement.setString(1, sdf.format(cal.getTime())); statement.setString(2, action); statement.execute(); - } catch (SQLException e) { + } catch (SQLException e) { System.err.println(e.getMessage()); } } diff --git a/core-codemods/src/test/resources/sonar-sql-injection-s2077/supported/SqlInjectionChallenge.java.after b/core-codemods/src/test/resources/sonar-sql-injection-s2077/supported/SqlInjectionChallenge.java.after index eb7086476..f74fbc82d 100644 --- a/core-codemods/src/test/resources/sonar-sql-injection-s2077/supported/SqlInjectionChallenge.java.after +++ b/core-codemods/src/test/resources/sonar-sql-injection-s2077/supported/SqlInjectionChallenge.java.after @@ -68,6 +68,7 @@ public class SqlInjectionChallenge extends AssignmentEndpoint { "select userid from sql_challenge_users where userid = ?"; PreparedStatement statement = connection.prepareStatement(checkUserQuery); statement.setString(1, username_reg); + ResultSet resultSet = statement.execute(); if (resultSet.next()) { if (username_reg.contains("tom'")) { @@ -84,7 +85,6 @@ public class SqlInjectionChallenge extends AssignmentEndpoint { preparedStatement.execute(); attackResult = success(this).feedback("user.created").feedbackArgs(username_reg).build(); } - } catch (SQLException e) { attackResult = failed(this).output("Something went wrong").build(); } diff --git a/core-codemods/src/test/resources/sql-parameterizer/Test.java.after b/core-codemods/src/test/resources/sql-parameterizer/defaultTransformation/Test.java.after similarity index 99% rename from core-codemods/src/test/resources/sql-parameterizer/Test.java.after rename to core-codemods/src/test/resources/sql-parameterizer/defaultTransformation/Test.java.after index 160311c2d..3cb08a594 100644 --- a/core-codemods/src/test/resources/sql-parameterizer/Test.java.after +++ b/core-codemods/src/test/resources/sql-parameterizer/defaultTransformation/Test.java.after @@ -23,7 +23,7 @@ public final class Test { stmt.setString(1, input); var rs = stmt.execute(); return rs; - } + } public ResultSet nameConflict(String input) throws SQLException { int stmt = 0; @@ -33,7 +33,7 @@ public final class Test { ResultSet rs = statement.execute(); stmt++; return rs; - } + } public ResultSet doubleNameConflict(String input) throws SQLException { int stmt = 0; @@ -44,7 +44,7 @@ public final class Test { ResultSet rs = stmt1.execute(); stmt = stmt + statement; return rs; - } + } public ResultSet tryResource(String input) throws SQLException { String sql = "SELECT * FROM USERS WHERE USER = ?"; @@ -62,7 +62,7 @@ public final class Test { stmt.setString(1, "user_" + input + "_name"); stmt.setString(2, input2); return stmt.execute(); - } + } public ResultSet referencesAfterExecute(String input) throws SQLException { String sql = "SELECT * FROM USERS WHERE USER = ?"; diff --git a/core-codemods/src/test/resources/sql-parameterizer/Test.java.before b/core-codemods/src/test/resources/sql-parameterizer/defaultTransformation/Test.java.before similarity index 100% rename from core-codemods/src/test/resources/sql-parameterizer/Test.java.before rename to core-codemods/src/test/resources/sql-parameterizer/defaultTransformation/Test.java.before diff --git a/core-codemods/src/test/resources/sql-parameterizer/hijackTransformation/Test.java.after b/core-codemods/src/test/resources/sql-parameterizer/hijackTransformation/Test.java.after new file mode 100644 index 000000000..d0d749c29 --- /dev/null +++ b/core-codemods/src/test/resources/sql-parameterizer/hijackTransformation/Test.java.after @@ -0,0 +1,33 @@ +package com.acme.testcode; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +public final class Test { + + private Connection conn; + + public void queryAfterDeclaration() throws SQLException { + Statement stmt; + String query2 = "SELECT * FROM users WHERE username = ?"; + PreparedStatement statement = conn.prepareStatement(query2); + statement.setString(1, request.getParameter("username")); + ResultSet rs2 = statement.execute(); + stmt = statement; + while (rs2.next()) { + System.out.println("User: " + rs2.getString("username")); + } + String query3 = "SELECT * FROM users WHERE email = ?"; + stmt.close(); + PreparedStatement stmt1 = conn.prepareStatement(query3); + stmt1.setString(1, request.getParameter("email")); + ResultSet rs3 = stmt1.execute(); + stmt = stmt1; + while (rs3.next()) { + System.out.println("User: " + rs3.getString("username")); + } + } +} diff --git a/core-codemods/src/test/resources/sql-parameterizer/hijackTransformation/Test.java.before b/core-codemods/src/test/resources/sql-parameterizer/hijackTransformation/Test.java.before new file mode 100644 index 000000000..ad479a91d --- /dev/null +++ b/core-codemods/src/test/resources/sql-parameterizer/hijackTransformation/Test.java.before @@ -0,0 +1,27 @@ +package com.acme.testcode; + +import java.sql.Connection; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +public final class Test { + + private Connection conn; + + public void queryAfterDeclaration() throws SQLException { + Statement stmt = conn.createStatement(); + String username = request.getParameter("username"); + String query2 = "SELECT * FROM users WHERE username = '" + username + "'"; + ResultSet rs2 = stmt.executeQuery(query2); + while (rs2.next()) { + System.out.println("User: " + rs2.getString("username")); + } + String email = request.getParameter("email"); + String query3 = "SELECT * FROM users WHERE email = '" + email + "'"; + ResultSet rs3 = stmt.executeQuery(query3); + while (rs3.next()) { + System.out.println("User: " + rs3.getString("username")); + } + } +} diff --git a/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTTransforms.java b/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTTransforms.java index e36c79bea..efc62cfcb 100644 --- a/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTTransforms.java +++ b/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTTransforms.java @@ -26,7 +26,6 @@ import com.github.javaparser.resolution.types.ResolvedType; import java.util.ArrayList; import java.util.Optional; -import java.util.stream.IntStream; public final class ASTTransforms { /** Add an import in alphabetical order. */ @@ -87,26 +86,21 @@ public static void addStaticImportIfMissing(final CompilationUnit cu, final Stri */ public static void addStatementAt( final NodeWithStatements node, final Statement stmt, final int index) { + + var oldStatements = node.getStatements(); var newStatements = new ArrayList(); int i = 0; - for (var s : node.getStatements()) { + for (var s : oldStatements) { if (i == index) { newStatements.add(stmt); } newStatements.add(s); i++; } - - // rebuilds the whole statements list as to preserve proper children order. - - // workaround for maintaining indent, removes all but the first - IntStream.range(0, node.getStatements().size() - 1) - .forEach(j -> node.getStatements().removeLast()); - // replace the first with the new statement if needed - if (index == 0) { - node.getStatements().get(0).replace(stmt); + for (i = index; i < oldStatements.size(); i++) { + node.setStatement(i, newStatements.get(i)); } - newStatements.stream().skip(1).forEach(node.getStatements()::add); + node.addStatement(newStatements.get(newStatements.size() - 1)); } /** @@ -291,7 +285,13 @@ public static Expression removeEmptyStringConcatenation(final BinaryExpr binexp) public static void removeEmptyStringConcatenation(Node subtree) { subtree .findAll(BinaryExpr.class, Node.TreeTraversal.POSTORDER) - .forEach(binexp -> binexp.replace(removeEmptyStringConcatenation(binexp))); + .forEach( + binexp -> { + var newNode = removeEmptyStringConcatenation(binexp); + if (newNode != binexp) { + binexp.replace(newNode.clone()); + } + }); } /** Removes unused variables. */ diff --git a/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTs.java b/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTs.java index 40cd8d471..7bbf5108f 100644 --- a/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTs.java +++ b/framework/codemodder-base/src/main/java/io/codemodder/ast/ASTs.java @@ -186,11 +186,11 @@ public static Optional isScopeInMethodCall(final Expression expr * * @return A tuple with the above pattern in order sans the {@link SimpleName}. */ - public static Optional isPatternExprDeclarationOf( + public static Optional isPatternExprDeclarationOf( final Node node, final String name) { - if (node instanceof PatternExpr) { - var pexpr = (PatternExpr) node; - if (pexpr.getNameAsString().equals(name)) return Optional.of(pexpr); + if (node instanceof TypePatternExpr) { + var pexpr = (TypePatternExpr) node; + if (pexpr.getName().asString().equals(name)) return Optional.of(pexpr); } return Optional.empty(); } diff --git a/framework/codemodder-base/src/main/java/io/codemodder/remediation/sqlinjection/SQLParameterizer.java b/framework/codemodder-base/src/main/java/io/codemodder/remediation/sqlinjection/SQLParameterizer.java index 916b25e86..2e931c363 100644 --- a/framework/codemodder-base/src/main/java/io/codemodder/remediation/sqlinjection/SQLParameterizer.java +++ b/framework/codemodder-base/src/main/java/io/codemodder/remediation/sqlinjection/SQLParameterizer.java @@ -8,6 +8,7 @@ import com.github.javaparser.ast.expr.*; import com.github.javaparser.ast.expr.BinaryExpr.Operator; import com.github.javaparser.ast.stmt.ExpressionStmt; +import com.github.javaparser.ast.stmt.Statement; import io.codemodder.Either; import io.codemodder.ast.ASTTransforms; import io.codemodder.ast.ASTs; @@ -69,11 +70,8 @@ public static boolean isParameterizationCandidate(final MethodCallExpr methodCal final Predicate isFirstArgumentNotSLE = n -> n.getArguments().getFirst().map(e -> !(e instanceof StringLiteralExpr)).orElse(false); - // is execute of an statement object whose first argument is not a string? - if (isExecute.and(hasScopeSQLStatement.and(isFirstArgumentNotSLE)).test(methodCallExpr)) { - return true; - } - return false; + // is an `execute*()` call of a statement object whose first argument is not a string? + return isExecute.and(hasScopeSQLStatement.and(isFirstArgumentNotSLE)).test(methodCallExpr); // Thrown by the JavaParser Symbol Solver when it can't resolve types } catch (RuntimeException e) { @@ -94,22 +92,6 @@ public static Set fixableJdbcMethodNames() { private static final Set fixableJdbcMethodNames = Set.of("executeQuery", "execute", "executeLargeUpdate", "executeUpdate"); - /** - * Tries to find the source of an expression if it can be uniquely defined, otherwise, returns - * self. - */ - public static Expression resolveExpression(final Expression expr) { - return Optional.of(expr) - .map(e -> e instanceof NameExpr ? e.asNameExpr() : null) - .flatMap(n -> ASTs.findEarliestLocalDeclarationOf(n.getName())) - .map(s -> s instanceof LocalVariableDeclaration ? (LocalVariableDeclaration) s : null) - // TODO currently it assumes it is never assigned, add support for definite assignments here - .filter(ASTs::isFinalOrNeverAssigned) - .flatMap(lvd -> lvd.getVariableDeclarator().getInitializer()) - .map(SQLParameterizer::resolveExpression) - .orElse(expr); - } - private Optional isConnectionCreateStatement(final Expression expr) { final Predicate isConnection = e -> { @@ -119,12 +101,13 @@ private Optional isConnectionCreateStatement(final Expression ex return false; } }; + var stmtCreationMethods = List.of("createStatement", "prepareStatement"); return Optional.of(expr) .map(e -> e instanceof MethodCallExpr ? expr.asMethodCallExpr() : null) .filter( mce -> mce.getScope().filter(isConnection).isPresent() - && mce.getNameAsString().equals("createStatement")); + && (stmtCreationMethods.contains(mce.getNameAsString()))); } private Optional validateExecuteCall(final MethodCallExpr executeCall) { @@ -205,16 +188,17 @@ private Optional validateExecuteCall(final MethodCallExpr execut .map(expr -> expr instanceof NameExpr ? expr.asNameExpr() : null) .flatMap(ne -> ASTs.findEarliestLocalVariableDeclarationOf(ne, ne.getNameAsString())); - // Has a single assignment - // We erroniously assume that it always shadows the init expression // Needs some flow analysis to correctly address this case final Optional maybeSingleAssigned = maybeLVD .map(lvd -> ASTs.findAllAssignments(lvd).limit(2).toList()) - .filter(allAssignments -> allAssignments.size() == 1) - .map(allAssignments -> allAssignments.get(0)) + .filter(allAssignments -> !allAssignments.isEmpty()) + .map(allAssignments -> allAssignments.get(allAssignments.size() - 1)) .filter(assign -> assign.getTarget().isNameExpr()) - .filter(assign -> isConnectionCreateStatement(assign.getValue()).isPresent()); + .filter( + assign -> + isConnectionCreateStatement(ASTs.resolveLocalExpression(assign.getValue())) + .isPresent()); if (maybeSingleAssigned.isPresent()) { return maybeSingleAssigned.map(a -> Either.right(Either.left(a))); @@ -263,6 +247,25 @@ private Optional validateExecuteCall(final MethodCallExpr execut return Optional.of(stmtObject); } + private Optional> + validateStatementCreationExprForHijack( + final Either> stmtObject) { + if (stmtObject.isRight()) { + var maybelvd = + stmtObject + .getRight() + .ifLeftOrElseGet( + ae -> + ASTs.findEarliestLocalVariableDeclarationOf( + ae, ae.getTarget().asNameExpr().getNameAsString()), + lvd -> Optional.of(lvd)); + if (maybelvd.filter(lvd -> lvd instanceof ExpressionStmtVariableDeclaration).isPresent()) { + return Optional.of(stmtObject.getRight()); + } + } + return Optional.empty(); + } + /** Checks if a local declaration can change types to a subtype. */ private boolean canChangeTypes(final LocalVariableDeclaration localDeclaration) { final var allNameExpr = @@ -325,8 +328,8 @@ private boolean validateTryResource( return false; } - private String generateNameWithSuffix(final String name, final Node start) { - String actualName = preparedStatementNamePrefix; + private String generateNameWithSuffix(final Node start) { + String actualName = SQLParameterizer.preparedStatementNamePrefix; var maybeName = ASTs.findNonCallableSimpleNameSource(start, actualName); // Try for statement if (maybeName.isPresent()) { @@ -346,8 +349,15 @@ private String generateNameWithSuffix(final String name, final Node start) { return count == 0 ? actualName : nameWithSuffix; } + /** + * Fix the injections by replacing the injected expressions with a `?` parameter. + * + * @param injections A list of deques representing the expressions. + * @param resolvedMap A map containing the resolution of several expressions + * @return The list of expressions that were being injected + */ private List fixInjections( - final List> injections, Map resolvedMap) { + final List> injections, final Map resolvedMap) { final List combinedExpressions = new ArrayList<>(); for (final var injection : injections) { // fix start @@ -370,16 +380,12 @@ private List fixInjections( // build expression for parameters var combined = buildParameter(injection, resolvedMap); // add the suffix of start - if (prepend != "") { - final var newCombined = - new BinaryExpr(new StringLiteralExpr(prepend), combined, Operator.PLUS); - combined = newCombined; + if (!prepend.isEmpty()) { + combined = new BinaryExpr(new StringLiteralExpr(prepend), combined, Operator.PLUS); } // add the prefix of end - if (append != "") { - final var newCombined = - new BinaryExpr(combined, new StringLiteralExpr(append), Operator.PLUS); - combined = newCombined; + if (!append.isEmpty()) { + combined = new BinaryExpr(combined, new StringLiteralExpr(append), Operator.PLUS); } combinedExpressions.add(combined); } @@ -423,17 +429,62 @@ private Expression buildParameter( } /** - * The fix consists of the following: + * Parameterize the query strings and add the `setParameter` calls. + * + * @param pStatementVariableName The name of the PreparedStatemetnVariable that is used as a scope + * for the `setParameter` calls. + * @param anchoringStatement The statement that the `setParameter` calls will precede. + * @param parameterizedQuery The parameterized query strings. + * @return A statement that contains the start of + */ + private Statement gatherAndSetParameters( + final String pStatementVariableName, + final Statement anchoringStatement, + final QueryParameterizer parameterizedQuery) { + // Parameterize the query strings + final var queryParameters = + fixInjections( + parameterizedQuery.getInjections(), + parameterizedQuery.getLinearizedQuery().getResolvedExpressionsMap()); + + // Set the PreparedStatement parameters + var topStatement = anchoringStatement; + for (int i = queryParameters.size() - 1; i >= 0; i--) { + final var expr = queryParameters.get(i); + ExpressionStmt setStmt; + setStmt = + new ExpressionStmt( + new MethodCallExpr( + new NameExpr(pStatementVariableName), + "setString", + new NodeList<>(new IntegerLiteralExpr(String.valueOf(i + 1)), expr))); + ASTTransforms.addStatementBeforeStatement(topStatement, setStmt); + topStatement = setStmt; + } + + ASTTransforms.addImportIfMissing(compilationUnit, "java.sql.PreparedStatement"); + return topStatement; + } + + /** + * Apply the fix for the parameterization, which consists of the following steps: * *

(0) If the execute call is the following resource, break the try into two statements; * - *

(1.a) Create a new PreparedStatement pstmt object; + *

(1) Add a setString for every injection parameter; * - *

(1.b) Change Statement type to PreparedStatement and createStatement to prepareStatement; + *

(2.a) Create a new PreparedStatement pstmt object; * - *

(2) Add a setString for every injection parameter; + *

(2.b) Change Statement type to PreparedStatement and createStatement to prepareStatement; * *

(3) Change .execute*() to pstmt.execute(). + * + * @param stmtCreation Either a declaration of a java.sql.Statement object, assingment of a + * java.sql.Statement object, or a conn.createStatement() call; + * @param queryParameterizer The QueryParameterizer object that containing the query strings and + * parameter expressions + * @param executeCall The `.execute*()` call. + * @return */ private MethodCallExpr fix( final Either> stmtCreation, @@ -463,39 +514,27 @@ private MethodCallExpr fix( final String stmtName = stmtCreation.ifLeftOrElseGet( - mce -> generateNameWithSuffix(preparedStatementNamePrefix, mce), + mce -> generateNameWithSuffix(mce), assignOrLVD -> assignOrLVD.ifLeftOrElseGet( - a -> a.getTarget().asNameExpr().getNameAsString(), lvd -> lvd.getName())); + a -> a.getTarget().asNameExpr().getNameAsString(), + LocalVariableDeclaration::getName)); // (1) - final var combinedExpressions = - fixInjections( - queryParameterizer.getInjections(), - queryParameterizer.getLinearizedQuery().getResolvedExpressionsMap()); - - var topStatement = executeStmt; - for (int i = combinedExpressions.size() - 1; i >= 0; i--) { - final var expr = combinedExpressions.get(i); - ExpressionStmt setStmt = null; - setStmt = - new ExpressionStmt( - new MethodCallExpr( - new NameExpr(stmtName), - "setString", - new NodeList<>(new IntegerLiteralExpr(String.valueOf(i + 1)), expr))); - ASTTransforms.addStatementBeforeStatement(topStatement, setStmt); - topStatement = setStmt; - } + var topStatement = gatherAndSetParameters(stmtName, executeStmt, queryParameterizer); - ASTTransforms.addImportIfMissing(compilationUnit, "java.sql.PreparedStatement"); + // (3) + executeCall.setName("execute"); + executeCall.setScope(new NameExpr(stmtName)); + executeCall.setArguments(new NodeList<>()); // (2) + // Gather execute call arguments final var args = new NodeList(); args.addFirst(queryParameterizer.getRoot()); args.addAll( stmtCreation.ifLeftOrElseGet( - mce -> mce.getArguments(), + MethodCallExpr::getArguments, assignOrLVD -> assignOrLVD.ifLeftOrElseGet( a -> a.getValue().asMethodCallExpr().getArguments(), @@ -506,75 +545,84 @@ private MethodCallExpr fix( .asMethodCallExpr() .getArguments()))); - // (3) - executeCall.setName("execute"); - executeCall.setScope(new NameExpr(stmtName)); - executeCall.setArguments(new NodeList<>()); - + // Create the `prepareStatement()` call and return it MethodCallExpr pstmtCreation; - // (2.a) + // Treat each of the three cases separately + // (2.a) The statement is created directly from the Connection without a middle variable for the + // java.sql.Statement if (stmtCreation.isLeft()) { - pstmtCreation = - new MethodCallExpr(stmtCreation.getLeft().getScope().get(), "prepareStatement", args); - final var pstmtCreationStmt = - new ExpressionStmt( - new VariableDeclarationExpr( - new VariableDeclarator( - StaticJavaParser.parseType("PreparedStatement"), stmtName, pstmtCreation))); - ASTTransforms.addStatementBeforeStatement(topStatement, pstmtCreationStmt); - - // (2.b) + // (2.b) The statement is created directly and assigned to a named variable + pstmtCreation = createPSWithoutVariable(stmtCreation.getLeft(), args, topStatement, stmtName); } else { + // The statement is created with an assignment or declaration final var assignOrLVD = stmtCreation.getRight(); - if (assignOrLVD.isLeft()) { - pstmtCreation = assignOrLVD.getLeft().getValue().asMethodCallExpr(); - pstmtCreation.setArguments(args); - pstmtCreation.setName("prepareStatement"); - - // change the assignment - assignOrLVD.getLeft().setValue(StaticJavaParser.parseExpression("a")); - assignOrLVD.getLeft().setValue(pstmtCreation); - - // change the initialization to be null and its type to PreparedStatement - // This will only work assuming a single shadowing assignment, may require changes here in - // the future - var maybeLVD = - ASTs.findEarliestLocalVariableDeclarationOf( - assignOrLVD.getLeft().getTarget(), - assignOrLVD.getLeft().getTarget().asNameExpr().getNameAsString()); - if (maybeLVD.isPresent()) { - var vd = maybeLVD.get().getVariableDeclarator(); - vd.setInitializer(new NullLiteralExpr()); - vd.setType(StaticJavaParser.parseType("PreparedStatement")); - } + pstmtCreation = + assignOrLVD.ifLeftOrElseGet( + ae -> createPSFromAE(ae, args), lvd -> createPSFromLVD(lvd, args)); + } + return pstmtCreation; + } - } else { - assignOrLVD - .getRight() - .getVariableDeclarator() - .setType(StaticJavaParser.parseType("PreparedStatement")); - assignOrLVD - .getRight() - .getVariableDeclarator() - .getInitializer() - .ifPresent(expr -> expr.asMethodCallExpr().setName("prepareStatement")); - assignOrLVD - .getRight() - .getVariableDeclarator() - .getInitializer() - .ifPresent(expr -> expr.asMethodCallExpr().setArguments(args)); - pstmtCreation = - assignOrLVD - .getRight() - .getVariableDeclarator() - .getInitializer() - .get() - .asMethodCallExpr(); - } + private MethodCallExpr createPSWithoutVariable( + final MethodCallExpr directStatementCreation, + final NodeList args, + final Statement anchoringStatement, + final String stmtName) { + var pstmtCreation = + new MethodCallExpr(directStatementCreation.getScope().get(), "prepareStatement", args); + final var pstmtCreationStmt = + new ExpressionStmt( + new VariableDeclarationExpr( + new VariableDeclarator( + StaticJavaParser.parseType("PreparedStatement"), stmtName, pstmtCreation))); + ASTTransforms.addStatementBeforeStatement(anchoringStatement, pstmtCreationStmt); + return pstmtCreation; + } + + private MethodCallExpr createPSFromAE( + final AssignExpr assignExpr, final NodeList args) { + var pstmtCreation = assignExpr.getValue().asMethodCallExpr(); + pstmtCreation.setArguments(args); + pstmtCreation.setName("prepareStatement"); + + // change the assignment + assignExpr.setValue(StaticJavaParser.parseExpression("a")); + assignExpr.setValue(pstmtCreation); + + // change the initialization to be null and its type to PreparedStatement + // This will only work assuming a single shadowing assignment, may require changes here in + // the future + var maybeLVD = + ASTs.findEarliestLocalVariableDeclarationOf( + assignExpr.getTarget(), assignExpr.getTarget().asNameExpr().getNameAsString()); + if (maybeLVD.isPresent()) { + var vd = maybeLVD.get().getVariableDeclarator(); + vd.setInitializer(new NullLiteralExpr()); + vd.setType(StaticJavaParser.parseType("PreparedStatement")); } return pstmtCreation; } + private MethodCallExpr createPSFromLVD( + final LocalVariableDeclaration localVariableDeclaration, final NodeList args) { + localVariableDeclaration + .getVariableDeclarator() + .setType(StaticJavaParser.parseType("PreparedStatement")); + localVariableDeclaration + .getVariableDeclarator() + .getInitializer() + .ifPresent(expr -> expr.asMethodCallExpr().setName("prepareStatement")); + localVariableDeclaration + .getVariableDeclarator() + .getInitializer() + .ifPresent(expr -> expr.asMethodCallExpr().setArguments(args)); + return localVariableDeclaration + .getVariableDeclarator() + .getInitializer() + .get() + .asMethodCallExpr(); + } + private boolean resolvedInScope( final Either assignOrLVD, Expression expr) { if (assignOrLVD.isLeft()) { @@ -607,16 +655,107 @@ private boolean assignedOrDefinedInScope( final boolean assignedInScope = assignmentsInScope .flatMap(aexpr -> ASTs.hasNamedTarget(aexpr).stream()) - .anyMatch(nexpr -> nexpr.getNameAsString() == name.getNameAsString()); + .anyMatch(nexpr -> Objects.equals(nexpr.getNameAsString(), name.getNameAsString())); final boolean definedInScope = - ASTs.findNonCallableSimpleNameSource(name.getName()) - .filter(source -> scope.inScope(source)) - .isPresent(); + ASTs.findNonCallableSimpleNameSource(name.getName()).filter(scope::inScope).isPresent(); return assignedInScope || definedInScope; } + private Expression getConnectionExpression( + final Either stmtCreation) { + return stmtCreation + .ifLeftOrElseGet( + ae -> ASTs.resolveLocalExpression(ae.getValue()).asMethodCallExpr(), + lvd -> lvd.getDeclaration().getInitializer().get().asMethodCallExpr()) + .getScope() + .get(); + } + + private MethodCallExpr fixByHijackedStatement( + final Either stmtCreation, + final QueryParameterizer queryParameterizer, + final MethodCallExpr executeCall) { + var executeStmt = ASTs.findParentStatementFrom(executeCall).get(); + // get the statement object variable name + final String stmtName = + stmtCreation.ifLeftOrElseGet( + a -> a.getTarget().asNameExpr().getNameAsString(), LocalVariableDeclaration::getName); + // generate a name for the new PreparedStatement object + String pStmtName = generateNameWithSuffix(executeCall); + + final String connName = getConnectionExpression(stmtCreation).asNameExpr().getNameAsString(); + + var topStatement = executeStmt; + + // Replace the parameters with the `?` string and adds the `setParameter` calls + // Also, get the top `setParameter` statement + topStatement = gatherAndSetParameters(pStmtName, topStatement, queryParameterizer); + + // Add PreparedStmt stmt = conn.prepareStatement() assignment + // Need to clone the nodes in the arguments to make sure the parent node is properly set + MethodCallExpr prepareStatementCall = + new MethodCallExpr( + new NameExpr(connName), + "prepareStatement", + new NodeList<>(executeCall.getArguments().stream().map(n -> n.clone()).toList())); + ExpressionStmt pStmtCreation = + new ExpressionStmt( + new VariableDeclarationExpr( + new VariableDeclarator( + StaticJavaParser.parseType("PreparedStatement"), + pStmtName, + prepareStatementCall))); + ASTTransforms.addStatementBeforeStatement(topStatement, pStmtCreation); + topStatement = pStmtCreation; + ASTTransforms.addImportIfMissing(compilationUnit, "java.sql.PreparedStatement"); + + // Test if stmt.execute*() is the first usage of the stmt object + // If so, remove initializer + // otherwise add stmt.close() + if (isExecuteFirstUsageAfterDeclaration(stmtCreation, executeCall)) { + var lvd = stmtCreation.getRight(); + lvd.getVariableDeclarator().getInitializer().ifPresent(i -> i.remove()); + } else { + Statement closeOriginal = + new ExpressionStmt(new MethodCallExpr(new NameExpr(stmtName), new SimpleName("close"))); + ASTTransforms.addStatementBeforeStatement(topStatement, closeOriginal); + } + + // TODO will this work for every type of execute statement? or just executeQuery? + // change execute statement + executeCall.setName("execute"); + executeCall.setScope(new NameExpr(pStmtName)); + executeCall.setArguments(new NodeList<>()); + + // add stmt = pstmt after executeCall + Statement hijackAssignment = + new ExpressionStmt( + new AssignExpr( + new NameExpr(stmtName), new NameExpr(pStmtName), AssignExpr.Operator.ASSIGN)); + ASTTransforms.addStatementAfterStatement(executeStmt, hijackAssignment); + + return prepareStatementCall; + } + + private boolean isExecuteFirstUsageAfterDeclaration( + final Either stmtCreation, + final MethodCallExpr executeCall) { + if (stmtCreation.isRight()) { + var lvd = stmtCreation.getRight(); + // This is heuristics + return ASTs.findAllReferences(lvd).stream() + .findFirst() + .flatMap(e -> ASTs.isScopeInMethodCall(e)) + .filter(mce -> mce == executeCall) + .isPresent(); + } + // We could also apply this predicate to assignments and remove it, but that may require more + // checks + return false; + } + /** * Checks if {@code methodCall} is a query call that needs to be fixed and fixes if that's the * case. If the parameterization happened, returns the PreparedStatement creation. @@ -630,8 +769,7 @@ public Optional checkAndFix() { // validate the call itself first if (isParameterizationCandidate(executeCall) && validateExecuteCall(executeCall).isPresent()) { // Now find the stmt creation expression, if any and validate it - final var stmtObject = - findStatementCreationExpr(executeCall).flatMap(this::validateStatementCreationExpr); + final var stmtObject = findStatementCreationExpr(executeCall); if (stmtObject.isPresent()) { // Now look for injections @@ -652,8 +790,8 @@ public Optional checkAndFix() { queryp.getLinearizedQuery().getResolvedExpressionsMap().keySet().stream() .anyMatch(expr -> resolvedInScope(assignOrLVD, expr))); - //// Is any name in the linearized expression defined/assigned inside the scope of the - // Statement Object? + ////// Is any name in the linearized expression defined/assigned inside the scope of the + //// Statement Object? final boolean nameInScope = stmtObject .get() @@ -661,15 +799,26 @@ public Optional checkAndFix() { mcd -> false, assignOrLVD -> queryp.getLinearizedQuery().getLinearized().stream() - .filter(expr -> expr.isNameExpr()) - .map(expr -> expr.asNameExpr()) + .filter(Expression::isNameExpr) + .map(Expression::asNameExpr) .anyMatch(name -> assignedOrDefinedInScope(name, assignOrLVD))); - if (queryp.getInjections().isEmpty() || resolvedInScope || nameInScope) { + // No injections detected + if (queryp.getInjections().isEmpty()) { return Optional.empty(); } - return Optional.of(fix(stmtObject.get(), queryp, executeCall)); + // This means we can replace the Statement declaration or assignment + if (!nameInScope + && !resolvedInScope + && stmtObject.flatMap(this::validateStatementCreationExpr).isPresent()) { + return Optional.of(fix(stmtObject.get(), queryp, executeCall)); + } + // Otherwise we use the hijack strategy + var maybeStmtObject = stmtObject.flatMap(this::validateStatementCreationExprForHijack); + if (maybeStmtObject.isPresent()) { + return Optional.of(fixByHijackedStatement(maybeStmtObject.get(), queryp, executeCall)); + } } } return Optional.empty(); diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index d5c3bd51e..151a1040b 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -1,7 +1,7 @@ [versions] auto-value = "1.9" jackson = "2.13.1" -javaparser-core = "3.25.4" +javaparser-core = "3.26.2" javaparser-symbolsolver = "3.15.15" java-security-toolkit = "1.2.0" java-security-toolkit-xstream = "1.0.2"