Skip to content

Commit

Permalink
update parameterized test
Browse files Browse the repository at this point in the history
  • Loading branch information
llama90 committed Jun 7, 2024
1 parent 032d823 commit 31d8cf6
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 184 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@

import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.stream.Stream;

import org.apache.arrow.algorithm.sort.DefaultVectorComparators;
import org.apache.arrow.algorithm.sort.VectorValueComparator;
Expand All @@ -35,64 +35,57 @@
import org.apache.arrow.vector.VarCharVector;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
* Test cases for {@link ParallelSearcher}.
*/
@RunWith(Parameterized.class)
public class TestParallelSearcher {

private enum ComparatorType {
public enum ComparatorType {
EqualityComparator,
OrderingComparator;
}

private static final int VECTOR_LENGTH = 10000;

private final int threadCount;

private BufferAllocator allocator;

private ExecutorService threadPool;

private final ComparatorType comparatorType;

public TestParallelSearcher(ComparatorType comparatorType, int threadCount) {
this.comparatorType = comparatorType;
this.threadCount = threadCount;
}

@Parameterized.Parameters(name = "comparator type = {0}, thread count = {1}")
public static Collection<Object[]> getComparatorName() {
List<Object[]> params = new ArrayList<>();
static Stream<Arguments> getComparatorName() {
List<Arguments> params = new ArrayList<>();
int[] threadCounts = {1, 2, 5, 10, 20, 50};
for (ComparatorType type : ComparatorType.values()) {
for (int count : threadCounts) {
params.add(new Object[] {type, count});
params.add(Arguments.of(type, count));
}
}
return params;
return params.stream();
}

@BeforeEach
public void prepare() {
allocator = new RootAllocator(1024 * 1024);
threadPool = Executors.newFixedThreadPool(threadCount);
}

@AfterEach
public void shutdown() {
allocator.close();
threadPool.shutdown();
if (threadPool != null) {
threadPool.shutdown();
}
}

@Test
public void testParallelIntSearch() throws ExecutionException, InterruptedException {
@ParameterizedTest(name = "comparator type = {0}, thread count = {1}")
@MethodSource("getComparatorName")
public void testParallelIntSearch(ComparatorType comparatorType, int threadCount)
throws ExecutionException, InterruptedException {
threadPool = Executors.newFixedThreadPool(threadCount);
try (IntVector targetVector = new IntVector("targetVector", allocator);
IntVector keyVector = new IntVector("keyVector", allocator)) {
IntVector keyVector = new IntVector("keyVector", allocator)) {
targetVector.allocateNew(VECTOR_LENGTH);
keyVector.allocateNew(VECTOR_LENGTH);

Expand All @@ -119,8 +112,11 @@ public void testParallelIntSearch() throws ExecutionException, InterruptedExcept
}
}

@Test
public void testParallelStringSearch() throws ExecutionException, InterruptedException {
@ParameterizedTest(name = "comparator type = {0}, thread count = {1}")
@MethodSource("getComparatorName")
public void testParallelStringSearch(ComparatorType comparatorType, int threadCount)
throws ExecutionException, InterruptedException {
threadPool = Executors.newFixedThreadPool(threadCount);
try (VarCharVector targetVector = new VarCharVector("targetVector", allocator);
VarCharVector keyVector = new VarCharVector("keyVector", allocator)) {
targetVector.allocateNew(VECTOR_LENGTH);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.util.Arrays;
import java.util.Collection;
import java.util.stream.Stream;

import org.apache.arrow.algorithm.sort.DefaultVectorComparators;
import org.apache.arrow.algorithm.sort.VectorValueComparator;
Expand All @@ -29,24 +28,17 @@
import org.apache.arrow.vector.IntVector;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

/**
* Test cases for {@link VectorRangeSearcher}.
*/
@RunWith(Parameterized.class)
public class TestVectorRangeSearcher {

private BufferAllocator allocator;

private int repeat;

public TestVectorRangeSearcher(int repeat) {
this.repeat = repeat;
}

@BeforeEach
public void prepare() {
allocator = new RootAllocator(1024 * 1024);
Expand All @@ -57,8 +49,9 @@ public void shutdown() {
allocator.close();
}

@Test
public void testGetLowerBounds() {
@ParameterizedTest(name = "repeat = {0}")
@MethodSource("getRepeat")
public void testGetLowerBounds(int repeat) {
final int maxValue = 100;
try (IntVector intVector = new IntVector("int vec", allocator)) {
// allocate vector
Expand Down Expand Up @@ -86,11 +79,12 @@ public void testGetLowerBounds() {
}
}

@Test
public void testGetLowerBoundsNegative() {
@ParameterizedTest(name = "repeat = {0}")
@MethodSource("getRepeat")
public void testGetLowerBoundsNegative(int repeat) {
final int maxValue = 100;
try (IntVector intVector = new IntVector("int vec", allocator);
IntVector negVector = new IntVector("neg vec", allocator)) {
IntVector negVector = new IntVector("neg vec", allocator)) {
// allocate vector
intVector.allocateNew(maxValue * repeat);
intVector.setValueCount(maxValue * repeat);
Expand Down Expand Up @@ -120,8 +114,9 @@ public void testGetLowerBoundsNegative() {
}
}

@Test
public void testGetUpperBounds() {
@ParameterizedTest(name = "repeat = {0}")
@MethodSource("getRepeat")
public void testGetUpperBounds(int repeat) {
final int maxValue = 100;
try (IntVector intVector = new IntVector("int vec", allocator)) {
// allocate vector
Expand Down Expand Up @@ -149,8 +144,9 @@ public void testGetUpperBounds() {
}
}

@Test
public void testGetUpperBoundsNegative() {
@ParameterizedTest(name = "repeat = {0}")
@MethodSource("getRepeat")
public void testGetUpperBoundsNegative(int repeat) {
final int maxValue = 100;
try (IntVector intVector = new IntVector("int vec", allocator);
IntVector negVector = new IntVector("neg vec", allocator)) {
Expand Down Expand Up @@ -183,13 +179,12 @@ public void testGetUpperBoundsNegative() {
}
}

@Parameterized.Parameters(name = "repeat = {0}")
public static Collection<Object[]> getRepeat() {
return Arrays.asList(
new Object[]{1},
new Object[]{2},
new Object[]{5},
new Object[]{10}
static Stream<Arguments> getRepeat() {
return Stream.of(
Arguments.of(1),
Arguments.of(2),
Arguments.of(5),
Arguments.of(10)
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@
import org.apache.arrow.vector.testing.ValueVectorDataPopulator;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

/**
* Test cases for {@link FixedWidthOutOfPlaceVectorSorter}.
Expand All @@ -44,10 +45,6 @@ public class TestFixedWidthOutOfPlaceVectorSorter extends TestOutOfPlaceVectorSo

private BufferAllocator allocator;

public TestFixedWidthOutOfPlaceVectorSorter(boolean generalSorter) {
super(generalSorter);
}

<V extends BaseFixedWidthVector> OutOfPlaceVectorSorter<V> getSorter() {
return generalSorter ? new GeneralOutOfPlaceVectorSorter<>() : new FixedWidthOutOfPlaceVectorSorter<>();
}
Expand All @@ -62,8 +59,10 @@ public void shutdown() {
allocator.close();
}

@Test
public void testSortByte() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortByte(boolean generalSorter) {
setup(generalSorter);
try (TinyIntVector vec = new TinyIntVector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
Expand All @@ -85,7 +84,7 @@ public void testSortByte() {
VectorValueComparator<TinyIntVector> comparator = DefaultVectorComparators.createDefaultComparator(vec);

TinyIntVector sortedVec =
(TinyIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null);
(TinyIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null);
sortedVec.allocateNew(vec.getValueCount());
sortedVec.setValueCount(vec.getValueCount());

Expand All @@ -109,8 +108,10 @@ public void testSortByte() {
}
}

@Test
public void testSortShort() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortShort(boolean generalSorter) {
setup(generalSorter);
try (SmallIntVector vec = new SmallIntVector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
Expand All @@ -132,7 +133,7 @@ public void testSortShort() {
VectorValueComparator<SmallIntVector> comparator = DefaultVectorComparators.createDefaultComparator(vec);

SmallIntVector sortedVec =
(SmallIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null);
(SmallIntVector) vec.getField().getFieldType().createNewSingleVector("", allocator, null);
sortedVec.allocateNew(vec.getValueCount());
sortedVec.setValueCount(vec.getValueCount());

Expand All @@ -156,8 +157,10 @@ public void testSortShort() {
}
}

@Test
public void testSortInt() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortInt(boolean generalSorter) {
setup(generalSorter);
try (IntVector vec = new IntVector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
Expand Down Expand Up @@ -202,8 +205,10 @@ public void testSortInt() {
}
}

@Test
public void testSortLong() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortLong(boolean generalSorter) {
setup(generalSorter);
try (BigIntVector vec = new BigIntVector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
Expand Down Expand Up @@ -248,8 +253,10 @@ public void testSortLong() {
}
}

@Test
public void testSortFloat() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortFloat(boolean generalSorter) {
setup(generalSorter);
try (Float4Vector vec = new Float4Vector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
Expand Down Expand Up @@ -294,8 +301,10 @@ public void testSortFloat() {
}
}

@Test
public void testSortDouble() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortDouble(boolean generalSorter) {
setup(generalSorter);
try (Float8Vector vec = new Float8Vector("", allocator)) {
vec.allocateNew(10);
vec.setValueCount(10);
Expand Down Expand Up @@ -340,8 +349,10 @@ public void testSortDouble() {
}
}

@Test
public void testSortInt2() {
@ParameterizedTest
@MethodSource("getParameter")
public void testSortInt2(boolean generalSorter) {
setup(generalSorter);
try (IntVector vec = new IntVector("", allocator)) {
ValueVectorDataPopulator.setVector(vec,
0, 1, 2, 3, 4, 5, 30, 31, 32, 33,
Expand Down
Loading

0 comments on commit 31d8cf6

Please sign in to comment.