diff --git a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java index 7836ca0e9..52c9c9200 100644 --- a/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java +++ b/src/main/java/org/openrewrite/java/testing/mockito/MockitoWhenOnStaticToMockStatic.java @@ -75,34 +75,19 @@ private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.Meth if (when != null && when.getArguments().get(0) instanceof J.MethodInvocation) { J.MethodInvocation whenArg = (J.MethodInvocation) when.getArguments().get(0); if (whenArg.getMethodType() != null && whenArg.getMethodType().hasFlags(Flag.Static)) { - J.Identifier clazz = (J.Identifier) whenArg.getSelect(); - if (clazz != null && clazz.getType() != null) { - String mockName = VariableNameUtils.generateVariableName("mock" + clazz.getSimpleName(), updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); - maybeAddImport("org.mockito.MockedStatic", false); - maybeAddImport("org.mockito.Mockito", "mockStatic"); - String template = String.format( - "try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n" + - " %2$s.when(#{any()}).thenReturn(#{any()});\n" + - "}", clazz.getSimpleName(), mockName); - J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) - .contextSensitive() - .imports("org.mockito.MockedStatic") - .staticImports("org.mockito.Mockito.mockStatic") - .build() - .apply(getCursor(), m.getCoordinates().replaceBody(), - whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) - .getBody().getStatements().get(0); - - restInTry.set(true); - - List precedingStatements = remainingStatements.subList(0, index); - return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( - try_.getBody().getStatements(), - maybeWrapStatementsInTryWithResourcesMockedStatic( - m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, try_))), - remainingStatements.subList(index + 1, remainingStatements.size()) - )))) - .withPrefix(statement.getPrefix()); + if (whenArg.getSelect() instanceof J.Identifier) { + J.Identifier clazz = (J.Identifier) whenArg.getSelect(); + if (clazz.getType() != null) { + return tryWithMockedStatic(m, remainingStatements, index, statement, clazz.getSimpleName(), whenArg, restInTry); + } + } else if (whenArg.getSelect() instanceof J.FieldAccess) { + J.FieldAccess fieldAccess = (J.FieldAccess) whenArg.getSelect(); + if (fieldAccess.getTarget() instanceof J.Identifier) { + J.Identifier clazz = (J.Identifier) fieldAccess.getTarget(); + if (clazz.getType() != null) { + return tryWithMockedStatic(m, remainingStatements, index, statement, clazz.getSimpleName(), whenArg, restInTry); + } + } } } } @@ -110,6 +95,42 @@ private List maybeWrapStatementsInTryWithResourcesMockedStatic(J.Meth return statement; }); } + + private J.Try tryWithMockedStatic( + J.MethodDeclaration m, + List remainingStatements, + Integer index, + Statement statement, + String simpleName, + J.MethodInvocation whenArg, + AtomicBoolean restInTry) { + String mockName = VariableNameUtils.generateVariableName("mock" + simpleName, updateCursor(m), VariableNameUtils.GenerationStrategy.INCREMENT_NUMBER); + maybeAddImport("org.mockito.MockedStatic", false); + maybeAddImport("org.mockito.Mockito", "mockStatic"); + String template = String.format( + "try(MockedStatic<%1$s> %2$s = mockStatic(%1$s.class)) {\n" + + " %2$s.when(#{any()}).thenReturn(#{any()});\n" + + "}", simpleName, mockName); + J.Try try_ = (J.Try) ((J.MethodDeclaration) JavaTemplate.builder(template) + .contextSensitive() + .imports("org.mockito.MockedStatic") + .staticImports("org.mockito.Mockito.mockStatic") + .build() + .apply(getCursor(), m.getCoordinates().replaceBody(), + whenArg, ((J.MethodInvocation) statement).getArguments().get(0))) + .getBody().getStatements().get(0); + + restInTry.set(true); + + List precedingStatements = remainingStatements.subList(0, index); + return try_.withBody(try_.getBody().withStatements(ListUtils.concatAll( + try_.getBody().getStatements(), + maybeWrapStatementsInTryWithResourcesMockedStatic( + m.withBody(m.getBody().withStatements(ListUtils.concat(precedingStatements, try_))), + remainingStatements.subList(index + 1, remainingStatements.size()) + )))) + .withPrefix(statement.getPrefix()); + } }); } }