diff --git a/src/main/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertions.java b/src/main/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertions.java index 8c49812d..b2b0cd53 100644 --- a/src/main/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertions.java +++ b/src/main/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertions.java @@ -15,6 +15,10 @@ import org.openrewrite.java.template.RecipeDescriptor; import org.testng.Assert; +import java.util.Iterator; +import java.util.Spliterators; +import java.util.stream.StreamSupport; + @RecipeDescriptor( name = "Migrate TestNG Asserts to Jupiter", description = "Migrate all TestNG Assertions to JUnit Jupiter Assertions." @@ -51,10 +55,49 @@ public static class MigrateAssertEqualsArrayWithMsg { } } + @RecipeDescriptor( + name = "Migrate `Assert#assertEquals(Iterator, Iterator)`", + description = "Migrates `org.testng.Assert#assertEquals(Iterator, Iterator)` " + + "to `org.junit.jupiter.api.Assertions#assertArrayEquals(Object[], Object[])`." + ) + public static class MigrateAssertEqualsIterator { + + @BeforeTemplate void before(Iterator actual, Iterator expected) { + Assert.assertEquals(actual, expected); + } + + @AfterTemplate void after(Iterator actual, Iterator expected) { + Assertions.assertArrayEquals( + StreamSupport.stream(Spliterators.spliteratorUnknownSize(expected, 0), false).toArray(), + StreamSupport.stream(Spliterators.spliteratorUnknownSize(actual, 0), false).toArray() + ); + } + } + + @RecipeDescriptor( + name = "Migrate `Assert#assertEquals(Iterator, Iterator, String)`", + description = "Migrates `org.testng.Assert#assertEquals(Iterator, Iterator, String)` " + + "to `org.junit.jupiter.api.Assertions#assertArrayEquals(Object[], Object[], String)`." + ) + public static class MigrateAssertEqualsIteratorWithMsg { + + @BeforeTemplate void before(Iterator actual, Iterator expected, String msg) { + Assert.assertEquals(actual, expected, msg); + } + + @AfterTemplate void after(Iterator actual, Iterator expected, String msg) { + Assertions.assertArrayEquals( + StreamSupport.stream(Spliterators.spliteratorUnknownSize(expected, 0), false).toArray(), + StreamSupport.stream(Spliterators.spliteratorUnknownSize(actual, 0), false).toArray(), + msg + ); + } + } + @RecipeDescriptor( name = "Replace `Assert#assertEquals(?, ?)` for primitive values, boxed types and other non-array objects", description = "Replace `org.testng.Assert#assertEquals(?, ?)` with `org.junit.jupiter.api.Assertions#assertEquals(?, ?)`." - + "Always run *after* `MigrateAssertEqualsArrayRecipe`." + + "Always run *after* `MigrateAssertEqualsArrayRecipe` and `MigrateAssertEqualsIteratorRecipe`." ) public static class MigrateAssertEquals { @@ -70,7 +113,7 @@ public static class MigrateAssertEquals { @RecipeDescriptor( name = "Replace `Assert#assertEquals(?, ?, String)` for primitive values, boxed types and other non-array objects", description = "Replace `org.testng.Assert#assertEquals(?, ?, String)` with `org.junit.jupiter.api.Assertions#assertEquals(?, ?, String)`." - + "Always run *after* `MigrateAssertEqualsArrayWithMsgRecipe`." + + "Always run *after* `MigrateAssertEqualsArrayRecipe` and `MigrateAssertEqualsIteratorRecipe`." ) public static class MigrateAssertEqualsWithMsg { diff --git a/src/test/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertionsTests.java b/src/test/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertionsTests.java index 85cfdf31..75ad3a34 100644 --- a/src/test/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertionsTests.java +++ b/src/test/java/io/github/mboegers/openrewrite/testngtojupiter/MigrateAssertionsTests.java @@ -171,6 +171,43 @@ void testMethod() { """.formatted(actual, expected) )); } + + @Test void becomesSpecialAssertArrayEquals_forIterators() { + // language=java + rewriteRun(java( + """ + import java.util.Iterator; + import java.util.List; + import org.testng.Assert; + + class MyTest { + void testMethod() { + Iterator actual = List.of("a", "b").iterator(); + Iterator expected = List.of("b", "a").iterator(); + + Assert.assertEquals(actual, expected, "Kaboom."); + } + } + """, + """ + import org.junit.jupiter.api.Assertions; + + import java.util.Iterator; + import java.util.List; + import java.util.Spliterators; + import java.util.stream.StreamSupport; + + class MyTest { + void testMethod() { + Iterator actual = List.of("a", "b").iterator(); + Iterator expected = List.of("b", "a").iterator(); + + Assertions.assertArrayEquals(StreamSupport.stream(Spliterators.spliteratorUnknownSize(expected, 0), false).toArray(), StreamSupport.stream(Spliterators.spliteratorUnknownSize(actual, 0), false).toArray(), "Kaboom."); + } + } + """ + )); + } } @Nested class WithoutErrorMessage { @@ -236,6 +273,43 @@ void testMethod() { """.formatted(actual, expected) )); } + + @Test void becomesSpecialAssertArrayEquals_forIterators() { + // language=java + rewriteRun(java( + """ + import java.util.Iterator; + import java.util.List; + import org.testng.Assert; + + class MyTest { + void testMethod() { + Iterator actual = List.of("a", "b").iterator(); + Iterator expected = List.of("b", "a").iterator(); + + Assert.assertEquals(actual, expected); + } + } + """, + """ + import org.junit.jupiter.api.Assertions; + + import java.util.Iterator; + import java.util.List; + import java.util.Spliterators; + import java.util.stream.StreamSupport; + + class MyTest { + void testMethod() { + Iterator actual = List.of("a", "b").iterator(); + Iterator expected = List.of("b", "a").iterator(); + + Assertions.assertArrayEquals(StreamSupport.stream(Spliterators.spliteratorUnknownSize(expected, 0), false).toArray(), StreamSupport.stream(Spliterators.spliteratorUnknownSize(actual, 0), false).toArray()); + } + } + """ + )); + } } } diff --git a/src/test/java/org/philzen/oss/research/ApiComparisonTest.java b/src/test/java/org/philzen/oss/research/ApiComparisonTest.java index 7d32b233..fb7a5336 100644 --- a/src/test/java/org/philzen/oss/research/ApiComparisonTest.java +++ b/src/test/java/org/philzen/oss/research/ApiComparisonTest.java @@ -1,5 +1,6 @@ package org.philzen.oss.research; +import com.google.common.collect.ImmutableList; import org.assertj.core.api.ThrowableAssert; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Nested; @@ -8,12 +9,18 @@ import org.testng.Assert; import java.util.Arrays; +import java.util.Collection; +import java.util.Spliterators; +import java.util.stream.StreamSupport; import static org.assertj.core.api.Assertions.assertThatExceptionOfType; import static org.assertj.core.api.Assertions.assertThatNoException; class ApiComparisonTest { + static final Collection ABC_list = ImmutableList.of("a", "b", "c"); + static final Collection CBA_list = ImmutableList.of("c", "b", "a"); + static final String[] ABC_array = {"a", "b", "c"}; static final String[] CBA_array = {"c", "b", "a"}; @@ -30,6 +37,26 @@ class ApiComparisonTest { // possible migration thisWillPass(() -> Assertions.assertEquals(Arrays.toString(ABC_array), Arrays.toString(ABC_array.clone()))); } + + @Tag("mismatch") + @Test void iterator() { + thisWillFail(() -> Assert.assertEquals(ABC_list.iterator(), CBA_list.iterator())); + thisWillFail(() -> Assertions.assertEquals(CBA_list.iterator(), ABC_list.iterator())); + + thisWillPass(() -> Assert.assertEquals(ABC_list.iterator(), ABC_list.iterator())); + thisWillFail(() -> Assertions.assertEquals(ABC_list.iterator(), ABC_list.iterator())); + + // possible migration + thisWillPass(() -> Assertions.assertArrayEquals( + StreamSupport.stream(Spliterators.spliteratorUnknownSize(ABC_list.iterator(), 0), false).toArray(), + StreamSupport.stream(Spliterators.spliteratorUnknownSize(ABC_list.iterator(), 0), false).toArray() + )); + + thisWillFail(() -> Assertions.assertArrayEquals( + StreamSupport.stream(Spliterators.spliteratorUnknownSize(ABC_list.iterator(), 0), false).toArray(), + StreamSupport.stream(Spliterators.spliteratorUnknownSize(CBA_list.iterator(), 0), false).toArray() + )); + } } void thisWillPass(final ThrowableAssert.ThrowingCallable code) {