diff --git a/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java b/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java
index d40f64844797..5d611fed6b82 100644
--- a/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java
+++ b/core/src/main/java/org/apache/iceberg/util/ParallelIterable.java
@@ -257,17 +257,21 @@ private static class Task<T> implements Supplier<Optional<Task<T>>>, Closeable {
     @Override
     public Optional<Task<T>> get() {
       try {
+        if (queue.size() >= approximateMaxQueueSize) {
+          // Yield when queue is over the size limit. Task will be resubmitted later and continue
+          // the work.
+          //
+          // Yield only before the iterator is opened: an open iterator may hold a constrained
+          // resource, and a task that yields while holding it can leave worker threads blocked
+          // on other tasks waiting for the same resource (deadlock).
+          return Optional.of(this);
+        }
+
         if (iterator == null) {
           iterator = input.iterator();
         }
 
         while (iterator.hasNext()) {
-          if (queue.size() >= approximateMaxQueueSize) {
-            // Yield when queue is over the size limit. Task will be resubmitted later and continue
-            // the work.
-            return Optional.of(this);
-          }
-
           T next = iterator.next();
           if (closed.get()) {
             break;
diff --git a/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java b/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java
index 5e37e0390db9..44bd8c371b46 100644
--- a/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java
+++ b/core/src/test/java/org/apache/iceberg/util/TestParallelIterable.java
@@ -25,6 +25,7 @@
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
+import java.util.concurrent.Semaphore;
 import java.util.concurrent.TimeUnit;
 import java.util.stream.IntStream;
 import java.util.stream.Stream;
@@ -39,6 +40,7 @@
 import org.apache.iceberg.util.ParallelIterable.ParallelIterator;
 import org.awaitility.Awaitility;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.Timeout;
 
 public class TestParallelIterable {
   @Test
@@ -148,7 +150,7 @@ public void limitQueueSize() {
             .collect(ImmutableMultiset.toImmutableMultiset());
 
     int maxQueueSize = 20;
-    ExecutorService executor = Executors.newCachedThreadPool();
+    ExecutorService executor = Executors.newSingleThreadExecutor();
     ParallelIterable<Integer> parallelIterable =
         new ParallelIterable<>(iterables, executor, maxQueueSize);
     ParallelIterator<Integer> iterator = (ParallelIterator<Integer>) parallelIterable.iterator();
@@ -158,7 +160,7 @@ public void limitQueueSize() {
     while (iterator.hasNext()) {
       assertThat(iterator.queueSize())
           .as("iterator internal queue size")
-          .isLessThanOrEqualTo(maxQueueSize + iterables.size());
+          .isLessThanOrEqualTo(100);
       actualValues.add(iterator.next());
     }
 
@@ -171,38 +173,68 @@ public void limitQueueSize() {
   }
 
   @Test
-  public void queueSizeOne() {
-    List<Iterable<Integer>> iterables =
-        ImmutableList.of(
-            () -> IntStream.range(0, 100).iterator(),
-            () -> IntStream.range(0, 100).iterator(),
-            () -> IntStream.range(0, 100).iterator());
+  @Timeout(10)
+  public void noDeadlock() {
+    // Models iterables backed by a constrained resource (e.g. a connection pool): opening an
+    // iterator takes the only permit of a shared semaphore and closing the iterable returns it.
+    // With a single worker thread and a queue limit of 1, a task that yielded while holding the
+    // permit would block the other task forever; the timeout turns that deadlock into a failure.
+    ExecutorService executor = Executors.newFixedThreadPool(1);
+    try {
+      Semaphore semaphore = new Semaphore(1);
 
-    Multiset<Integer> expectedValues =
-        IntStream.range(0, 100)
-            .boxed()
-            .flatMap(i -> Stream.of(i, i, i))
-            .collect(ImmutableMultiset.toImmutableMultiset());
+      List<Iterable<Integer>> iterablesA =
+          ImmutableList.of(
+              testIterable(
+                  semaphore::acquire, semaphore::release, IntStream.range(0, 100).iterator()));
+      List<Iterable<Integer>> iterablesB =
+          ImmutableList.of(
+              testIterable(
+                  semaphore::acquire, semaphore::release, IntStream.range(200, 300).iterator()));
 
-    ExecutorService executor = Executors.newCachedThreadPool();
-    ParallelIterable<Integer> parallelIterable = new ParallelIterable<>(iterables, executor, 1);
-    ParallelIterator<Integer> iterator = (ParallelIterator<Integer>) parallelIterable.iterator();
+      ParallelIterable<Integer> parallelIterableA = new ParallelIterable<>(iterablesA, executor, 1);
+      ParallelIterable<Integer> parallelIterableB = new ParallelIterable<>(iterablesB, executor, 1);
 
-    Multiset<Integer> actualValues = HashMultiset.create();
+      parallelIterableA.iterator().next();
+      parallelIterableB.iterator().next();
+    } finally {
+      // Always stop the worker thread, even if the calls above throw.
+      executor.shutdownNow();
+    }
+  }
 
-    while (iterator.hasNext()) {
-      assertThat(iterator.queueSize())
-          .as("iterator internal queue size")
-          .isLessThanOrEqualTo(1 + iterables.size());
-      actualValues.add(iterator.next());
-    }
+  private <T> CloseableIterable<T> testIterable(
+      RunnableWithException open, RunnableWithException close, Iterator<T> iterator) {
+    return new CloseableIterable<T>() {
+      @Override
+      public void close() {
+        runUnchecked(close);
+      }
 
-    assertThat(actualValues)
-        .as("multiset of values returned by the iterator")
-        .isEqualTo(expectedValues);
+      @Override
+      public CloseableIterator<T> iterator() {
+        runUnchecked(open);
+        return CloseableIterator.withClose(iterator);
+      }
+    };
+  }
 
-    iterator.close();
-    executor.shutdownNow();
+  /** Runs {@code action}, wrapping anything it throws in a {@link RuntimeException}. */
+  private static void runUnchecked(RunnableWithException action) {
+    try {
+      action.run();
+    } catch (InterruptedException e) {
+      // Keep the interrupt visible to the calling thread (e.g. after executor.shutdownNow()).
+      Thread.currentThread().interrupt();
+      throw new RuntimeException(e);
+    } catch (Exception e) {
+      throw new RuntimeException(e);
+    }
+  }
+
+  @FunctionalInterface
+  private interface RunnableWithException {
+    void run() throws Exception;
   }
 
   private void queueHasElements(ParallelIterator<Integer> iterator) {