Skip to content

Commit

Permalink
Prevent class cast exception in MockitoWhenOnStaticToMockStatic
Browse files Browse the repository at this point in the history
Fixes #644
  • Loading branch information
timtebeek committed Nov 25, 2024
1 parent 0f8f838 commit 72dbc65
Showing 1 changed file with 49 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -75,41 +75,62 @@ private List<Statement> maybeWrapStatementsInTryWithResourcesMockedStatic(J.Meth
if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) {
J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0);
if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Flag.Static)) {
J.Identifier clazz = (J.Identifier) whenArg.getSelect();
if (clazz != null && clazz.getType() != null) {
String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER);
maybeAddImport("org.mockito.MockedStatic", false);
maybeAddImport("org.mockito.Mockito", "mockStatic");
String template = String.format(
"try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n" +
" %2$s.when(#{any()}).thenReturn(#{any()});\n" +
"}", clazz.getSimpleName(), mockName);
J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template)
.contextSensitive()
.imports("org.mockito.MockedStatic")
.staticImports("org.mockito.Mockito.mockStatic")
.build()
.apply(getCursor(), m.getCoordinates().replaceBody(),
whenArg, ((J.MethodInvocation) statement).getArguments().get(0)))
.getBody().getStatements().get(0);

restInTry.set(true);

List<Statement> precedingStatements = remainingStatements.subList(0, index);
return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll(
try_.getBody().getStatements(),
maybeWrapStatementsInTryWithResourcesMockedStatic(
m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, try_))),
remainingStatements.subList(index + 1, remainingStatements.size())
))))
.withPrefix(statement.getPrefix());
if (whenArg.getSelect() instanceof J.Identifier) {
J.Identifier clazz = (J.Identifier) whenArg.getSelect();
if (clazz.getType() != null) {
return tryWithMockedStatic(m, remainingStatements, index, statement, clazz.getSimpleName(), whenArg, restInTry);
}
} else if (whenArg.getSelect() instanceof J.FieldAccess) {
J.FieldAccess fieldAccess = (J.FieldAccess) whenArg.getSelect();
if (fieldAccess.getTarget() instanceof J.Identifier) {
J.Identifier clazz = (J.Identifier) fieldAccess.getTarget();
if (clazz.getType() != null) {
return tryWithMockedStatic(m, remainingStatements, index, statement, clazz.getSimpleName(), whenArg, restInTry);
}
}
}
}
}
}
return statement;
});
}

private J.Try tryWithMockedStatic(
J.MethodDeclaration m,
List<Statement> remainingStatements,
Integer index,
Statement statement,
String simpleName,
J.MethodInvocation whenArg,
AtomicBoolean restInTry) {
String mockName = VariableNameUtils.generateVariableName("mock" + simpleName, updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER);
maybeAddImport("org.mockito.MockedStatic", false);
maybeAddImport("org.mockito.Mockito", "mockStatic");
String template = String.format(
"try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n" +
" %2$s.when(#{any()}).thenReturn(#{any()});\n" +
"}", simpleName, mockName);
J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template)
.contextSensitive()
.imports("org.mockito.MockedStatic")
.staticImports("org.mockito.Mockito.mockStatic")
.build()
.apply(getCursor(), m.getCoordinates().replaceBody(),
whenArg, ((J.MethodInvocation) statement).getArguments().get(0)))
.getBody().getStatements().get(0);

restInTry.set(true);

List<Statement> precedingStatements = remainingStatements.subList(0, index);
return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll(
try_.getBody().getStatements(),
maybeWrapStatementsInTryWithResourcesMockedStatic(
m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, try_))),
remainingStatements.subList(index + 1, remainingStatements.size())
))))
.withPrefix(statement.getPrefix());
}
});
}
}

0 comments on commit 72dbc65

Please sign in to comment.