diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/DefaultLoadBalancer.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/DefaultLoadBalancer.java index 69e2ff6268..d656609a09 100644 --- a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/DefaultLoadBalancer.java +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/DefaultLoadBalancer.java @@ -18,14 +18,17 @@ import io.servicetalk.client.api.ConnectionFactory; import io.servicetalk.client.api.LoadBalancedConnection; import io.servicetalk.client.api.ServiceDiscovererEvent; +import io.servicetalk.concurrent.CompletableSource; import io.servicetalk.concurrent.PublisherSource.Processor; import io.servicetalk.concurrent.PublisherSource.Subscriber; import io.servicetalk.concurrent.PublisherSource.Subscription; import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.api.CompositeCloseable; import io.servicetalk.concurrent.api.ListenableAsyncCloseable; +import io.servicetalk.concurrent.api.Processors; import io.servicetalk.concurrent.api.Publisher; import io.servicetalk.concurrent.api.Single; +import io.servicetalk.concurrent.api.SourceAdapters; import io.servicetalk.concurrent.internal.SequentialCancellable; import io.servicetalk.context.api.ContextMap; @@ -34,21 +37,13 @@ import java.util.ArrayList; import java.util.Collection; -import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; import java.util.List; -import java.util.ListIterator; import java.util.Map; import java.util.Map.Entry; -import java.util.Spliterator; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.AtomicLongFieldUpdater; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; -import java.util.function.Consumer; import java.util.function.Predicate; -import java.util.function.UnaryOperator; -import java.util.stream.Stream; import javax.annotation.Nullable; import static io.servicetalk.client.api.LoadBalancerReadyEvent.LOAD_BALANCER_NOT_READY_EVENT; @@ -81,10 +76,6 @@ final class DefaultLoadBalancer usedHostsUpdater = - AtomicReferenceFieldUpdater.newUpdater(DefaultLoadBalancer.class, List.class, "usedHosts"); - @SuppressWarnings("rawtypes") private static final AtomicLongFieldUpdater nextResubscribeTimeUpdater = AtomicLongFieldUpdater.newUpdater(DefaultLoadBalancer.class, "nextResubscribeTime"); @@ -92,9 +83,13 @@ final class DefaultLoadBalancer> usedHosts = emptyList(); + private volatile boolean isClosed; private final String targetResource; + private final SequentialExecutor sequentialExecutor; private final String lbDescription; private final HostSelector hostSelector; private final Publisher>> eventPublisher; @@ -137,28 +132,12 @@ final class DefaultLoadBalancer { - discoveryCancellable.cancel(); - eventStreamProcessor.onComplete(); - final CompositeCloseable compositeCloseable; - for (;;) { - List> currentList = usedHosts; - if (isClosedList(currentList) || - usedHostsUpdater.compareAndSet(this, currentList, new ClosedList<>(currentList))) { - compositeCloseable = newCompositeCloseable().appendAll(currentList).appendAll(connectionFactory); - LOGGER.debug("{} is closing {}gracefully. Last seen addresses (size={}): {}.", - this, graceful ? "" : "non", currentList.size(), currentList); - break; - } - } - return (graceful ? compositeCloseable.closeAsyncGracefully() : compositeCloseable.closeAsync()) - .beforeOnError(t -> { - if (!graceful) { - usedHosts = new ClosedList<>(emptyList()); - } - }) - .beforeOnComplete(() -> usedHosts = new ClosedList<>(emptyList())); + this.sequentialExecutor = new SequentialExecutor((uncaughtException) -> { + LOGGER.error("{}: Uncaught exception in SequentialExecutor triggered closing of the load balancer.", + this, uncaughtException); + closeAsync().subscribe(); }); + this.asyncCloseable = toAsyncCloseable(this::doClose); // Maintain a Subscriber so signals are always delivered to replay and new Subscribers get the latest signal. eventStream.ignoreElements().subscribe(); subscribeToEvents(false); @@ -178,6 +157,38 @@ private void subscribeToEvents(boolean resubscribe) { } } + // This method is called eagerly, meaning the completable will be immediately subscribed to, + // so we don't need to do any Completable.defer business. + private Completable doClose(final boolean graceful) { + CompletableSource.Processor processor = Processors.newCompletableProcessor(); + sequentialExecutor.execute(() -> { + try { + if (!isClosed) { + discoveryCancellable.cancel(); + eventStreamProcessor.onComplete(); + } + isClosed = true; + List> currentList = usedHosts; + final CompositeCloseable compositeCloseable = newCompositeCloseable() + .appendAll(currentList) + .appendAll(connectionFactory); + LOGGER.debug("{} is closing {}gracefully. Last seen addresses (size={}): {}.", + this, graceful ? "" : "non", currentList.size(), currentList); + SourceAdapters.toSource((graceful ? compositeCloseable.closeAsyncGracefully() : + // We only want to empty the host list on error if we're closing non-gracefully. + compositeCloseable.closeAsync().beforeOnError(t -> + sequentialExecutor.execute(() -> usedHosts = emptyList())) + ) + // we want to always empty out the host list if we complete successfully + .beforeOnComplete(() -> sequentialExecutor.execute(() -> usedHosts = emptyList()))) + .subscribe(processor); + } catch (Throwable ex) { + processor.onError(ex); + } + }); + return SourceAdapters.fromSource(processor); + } + private static long nextResubscribeTime( final HealthCheckConfig config, final DefaultLoadBalancer lb) { final long lowerNanos = config.healthCheckResubscribeLowerBound; @@ -253,81 +264,71 @@ public void onNext(@Nullable final Collection sequentialOnNext(events)); + } - boolean sendReadyEvent; - List> nextHosts; - for (;;) { - // TODO: we have some weirdness in the event that we fail the CAS namely that we can create a host - // that never gets used but is orphaned. It's fine so long as there is nothing to close but that - // guarantee may not always hold in the future. - @SuppressWarnings("unchecked") - List> usedHosts = usedHostsUpdater.get(DefaultLoadBalancer.this); - if (isClosedList(usedHosts)) { - // We don't update if the load balancer is closed. - return; - } - nextHosts = new ArrayList<>(usedHosts.size() + events.size()); - sendReadyEvent = false; - - // First we make a map of addresses to events so that we don't get quadratic behavior for diffing. - // Unfortunately we need to make this every iteration of the CAS loop since we remove entries - // for hosts that already exist. If this results in to many collisions and map rebuilds we should - // re-assess how we manage concurrency for list mutations. - final Map> eventMap = new HashMap<>(); - for (ServiceDiscovererEvent event : events) { - ServiceDiscovererEvent old = eventMap.put(event.address(), event); - if (old != null) { - LOGGER.debug("Multiple ServiceDiscoveryEvent's detected for address {}. Event: {}.", - event.address(), event); - } + private void sequentialOnNext(Collection> events) { + assert events != null && !events.isEmpty(); + + if (isClosed) { + // nothing to do if the load balancer is closed. + return; + } + + boolean sendReadyEvent = false; + final List> nextHosts = new ArrayList<>(usedHosts.size() + events.size()); + final List> oldUsedHosts = usedHosts; + // First we make a map of addresses to events so that we don't get quadratic behavior for diffing. + final Map> eventMap = new HashMap<>(); + for (ServiceDiscovererEvent event : events) { + ServiceDiscovererEvent old = eventMap.put(event.address(), event); + if (old != null) { + LOGGER.debug("Multiple ServiceDiscoveryEvent's detected for address {}. Event: {}.", + event.address(), event); } + } - // First thing we do is go through the existing hosts and see if we need to transfer them. These - // will be all existing hosts that either don't have a matching discovery event or are not marked - // as unavailable. If they are marked unavailable, we need to close them (which is idempotent). - for (Host host : usedHosts) { - ServiceDiscovererEvent event = eventMap.remove(host.address()); - if (event == null) { - // Host doesn't have a SD update so just copy it over. + // First thing we do is go through the existing hosts and see if we need to transfer them. These + // will be all existing hosts that either don't have a matching discovery event or are not marked + // as unavailable. If they are marked unavailable, we need to close them. + for (Host host : oldUsedHosts) { + ServiceDiscovererEvent event = eventMap.remove(host.address()); + if (event == null) { + // Host doesn't have a SD update so just copy it over. + nextHosts.add(host); + } else if (AVAILABLE.equals(event.status())) { + // We only send the ready event if the previous host list was empty. + sendReadyEvent = oldUsedHosts.isEmpty(); + // If the host is already in CLOSED state, we should discard it and create a new entry. + // For duplicate ACTIVE events the marking succeeds, so we will not add a new entry. + if (host.markActiveIfNotClosed()) { nextHosts.add(host); - } else if (AVAILABLE.equals(event.status())) { - // We only send the ready event if the previous host list was empty. - sendReadyEvent = usedHosts.isEmpty(); - // If the host is already in CLOSED state, we should discard it and create a new entry. - // For duplicate ACTIVE events or for repeated activation due to failed CAS - // of replacing the usedHosts array the marking succeeds so we will not add a new entry. - if (host.markActiveIfNotClosed()) { - nextHosts.add(host); - } else { - nextHosts.add(createHost(event.address())); - } - } else if (EXPIRED.equals(event.status())) { - if (!host.markExpired()) { - nextHosts.add(host); - } - } else if (UNAVAILABLE.equals(event.status())) { - host.markClosed(); } else { - LOGGER.warn("{}: Unsupported Status in event:" + - " {} (mapped to {}). Leaving usedHosts unchanged: {}", - DefaultLoadBalancer.this, event, event.status(), nextHosts); - nextHosts.add(host); - } - } - // Now process events that didn't have an existing host. The only ones that we actually care - // about are the AVAILABLE events which result in a new host. - for (ServiceDiscovererEvent event : eventMap.values()) { - if (AVAILABLE.equals(event.status())) { - sendReadyEvent = true; nextHosts.add(createHost(event.address())); } + } else if (EXPIRED.equals(event.status())) { + if (!host.markExpired()) { + nextHosts.add(host); + } + } else if (UNAVAILABLE.equals(event.status())) { + host.markClosed(); + } else { + LOGGER.warn("{}: Unsupported Status in event:" + + " {} (mapped to {}). Leaving usedHosts unchanged: {}", + DefaultLoadBalancer.this, event, event.status(), nextHosts); + nextHosts.add(host); } - // We've now built the new list so now we need to CAS it before we can move on. This should only be - // racing with closing hosts and closing the whole LB so it shouldn't be common to lose the race. - if (usedHostsUpdater.compareAndSet(DefaultLoadBalancer.this, usedHosts, nextHosts)) { - break; + } + // Now process events that didn't have an existing host. The only ones that we actually care + // about are the AVAILABLE events which result in a new host. + for (ServiceDiscovererEvent event : eventMap.values()) { + if (AVAILABLE.equals(event.status())) { + sendReadyEvent = true; + nextHosts.add(createHost(event.address())); } } + // We've built the new list so now set it for consumption and then send our events. + usedHosts = nextHosts; LOGGER.debug("{}: now using addresses (size={}): {}.", DefaultLoadBalancer.this, nextHosts.size(), nextHosts); @@ -367,13 +368,20 @@ private Host createHost(ResolvedAddress addr) { Host host = new DefaultHost<>(DefaultLoadBalancer.this.toString(), addr, connectionFactory, linearSearchSpace, healthCheckConfig); host.onClose().afterFinally(() -> - usedHostsUpdater.updateAndGet(DefaultLoadBalancer.this, previousHosts -> { - @SuppressWarnings("unchecked") - List> previousHostsTyped = - (List>) previousHosts; - return listWithHostRemoved(previousHostsTyped, current -> current == host); - } - )).subscribe(); + sequentialExecutor.execute(() -> { + final List> currentHosts = usedHosts; + if (currentHosts.isEmpty()) { + // Can't remove an entry from an empty list. + return; + } + final List> nextHosts = listWithHostRemoved( + currentHosts, current -> current == host); + usedHosts = nextHosts; + if (nextHosts.isEmpty()) { + // We transitioned from non-empty to empty. That means we're not ready. + eventStreamProcessor.onNext(LOAD_BALANCER_NOT_READY_EVENT); + } + })).subscribe(); return host; } @@ -445,7 +453,7 @@ private Single selectConnection0(final Predicate selector, @Nullable final // It's possible that we're racing with updates from the `onNext` method but since it's intrinsically // racy it's fine to do these 'are there any hosts at all' checks here using the total host set. if (currentHosts.isEmpty()) { - return isClosedList(currentHosts) ? failedLBClosed(targetResource) : + return isClosed ? failedLBClosed(targetResource) : // This is the case when SD has emitted some items but none of the hosts are available. failed(Exceptions.StacklessNoAvailableHostException.newInstance( "No hosts are available to connect for " + targetResource + ".", @@ -503,172 +511,10 @@ public List>> usedAddresses() { return usedHosts.stream().map(host -> ((DefaultHost) host).asEntry()).collect(toList()); } - private static boolean isClosedList(List list) { - return list.getClass().equals(ClosedList.class); - } - private String makeDescription(String id, String targetResource) { return getClass().getSimpleName() + "{" + "id=" + id + '@' + toHexString(identityHashCode(this)) + ", targetResource=" + targetResource + '}'; } - - private static final class ClosedList implements List { - private final List delegate; - - private ClosedList(final List delegate) { - this.delegate = requireNonNull(delegate); - } - - @Override - public int size() { - return delegate.size(); - } - - @Override - public boolean isEmpty() { - return delegate.isEmpty(); - } - - @Override - public boolean contains(final Object o) { - return delegate.contains(o); - } - - @Override - public Iterator iterator() { - return delegate.iterator(); - } - - @Override - public void forEach(final Consumer action) { - delegate.forEach(action); - } - - @Override - public Object[] toArray() { - return delegate.toArray(); - } - - @Override - public T1[] toArray(final T1[] a) { - return delegate.toArray(a); - } - - @Override - public boolean add(final T t) { - return delegate.add(t); - } - - @Override - public boolean remove(final Object o) { - return delegate.remove(o); - } - - @Override - public boolean containsAll(final Collection c) { - return delegate.containsAll(c); - } - - @Override - public boolean addAll(final Collection c) { - return delegate.addAll(c); - } - - @Override - public boolean addAll(final int index, final Collection c) { - return delegate.addAll(c); - } - - @Override - public boolean removeAll(final Collection c) { - return delegate.removeAll(c); - } - - @Override - public boolean removeIf(final Predicate filter) { - return delegate.removeIf(filter); - } - - @Override - public boolean retainAll(final Collection c) { - return delegate.retainAll(c); - } - - @Override - public void replaceAll(final UnaryOperator operator) { - delegate.replaceAll(operator); - } - - @Override - public void sort(final Comparator c) { - delegate.sort(c); - } - - @Override - public void clear() { - delegate.clear(); - } - - @Override - public T get(final int index) { - return delegate.get(index); - } - - @Override - public T set(final int index, final T element) { - return delegate.set(index, element); - } - - @Override - public void add(final int index, final T element) { - delegate.add(index, element); - } - - @Override - public T remove(final int index) { - return delegate.remove(index); - } - - @Override - public int indexOf(final Object o) { - return delegate.indexOf(o); - } - - @Override - public int lastIndexOf(final Object o) { - return delegate.lastIndexOf(o); - } - - @Override - public ListIterator listIterator() { - return delegate.listIterator(); - } - - @Override - public ListIterator listIterator(final int index) { - return delegate.listIterator(index); - } - - @Override - public List subList(final int fromIndex, final int toIndex) { - return new ClosedList<>(delegate.subList(fromIndex, toIndex)); - } - - @Override - public Spliterator spliterator() { - return delegate.spliterator(); - } - - @Override - public Stream stream() { - return delegate.stream(); - } - - @Override - public Stream parallelStream() { - return delegate.parallelStream(); - } - } } diff --git a/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/SequentialExecutor.java b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/SequentialExecutor.java new file mode 100644 index 0000000000..6ab6222c9c --- /dev/null +++ b/servicetalk-loadbalancer/src/main/java/io/servicetalk/loadbalancer/SequentialExecutor.java @@ -0,0 +1,111 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.loadbalancer; + +import io.servicetalk.concurrent.api.AsyncContext; + +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicReference; +import javax.annotation.Nullable; + +import static java.util.Objects.requireNonNull; + +/** + * A concurrency primitive for providing thread safety without using locks. + * + * A {@link SequentialExecutor} is queue of tasks that are executed one at a time in the order they were + * received. This provides a way to serialize work between threads without needing to use locks which can + * result in thread contention and thread deadlock scenarios. + */ +final class SequentialExecutor implements Executor { + + /** + * Handler of exceptions thrown by submitted Runnables. + */ + @FunctionalInterface + public interface ExceptionHandler { + + /** + * Handle the exception thrown from a submitted Runnable. + * Note that if this method throws the behavior is undefined. + * + * @param ex the Throwable thrown by the Runnable. + */ + void onException(Throwable ex); + } + + private final ExceptionHandler exceptionHandler; + private final AtomicReference tail = new AtomicReference<>(); + + SequentialExecutor(final ExceptionHandler exceptionHandler) { + this.exceptionHandler = requireNonNull(exceptionHandler, "exceptionHandler"); + } + + @Override + public void execute(Runnable command) { + // Make sure we propagate any sync contexts. + command = AsyncContext.wrapRunnable(requireNonNull(command, "command")); + final Cell next = new Cell(command); + Cell t = tail.getAndSet(next); + if (t != null) { + // Execution already started. Link the old tail to the new tail. + t.next = next; + } else { + // We are the first element in the queue so it's our responsibility to drain. + // Note that the getAndSet establishes the happens before with relation to the previous draining + // threads since we must successfully perform a CAS operation to terminate draining. + drain(next); + } + } + + private void drain(Cell next) { + for (;;) { + assert next != null; + try { + next.runnable.run(); + } catch (Throwable ex) { + exceptionHandler.onException(ex); + } + + // Attempt to get the next element. + Cell n = next.next; + if (n == null) { + // There doesn't seem to be another element linked. See if it was the tail and if so terminate draining. + // Note that a successful CAS established a happens-before relationship with future draining threads. + if (tail.compareAndSet(next, null)) { + break; + } + // next isn't the tail but the link hasn't resolved: we must poll until it does. + while ((n = next.next) == null) { + // Still not resolved: yield and then try again. + Thread.yield(); + } + } + next = n; + } + } + + private static final class Cell { + + final Runnable runnable; + @Nullable + volatile Cell next; + + Cell(Runnable runnable) { + this.runnable = runnable; + } + } +} diff --git a/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/SequentialExecutorTest.java b/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/SequentialExecutorTest.java new file mode 100644 index 0000000000..8720dae258 --- /dev/null +++ b/servicetalk-loadbalancer/src/test/java/io/servicetalk/loadbalancer/SequentialExecutorTest.java @@ -0,0 +1,199 @@ +/* + * Copyright © 2023 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.loadbalancer; + +import io.servicetalk.concurrent.api.AsyncContext; +import io.servicetalk.context.api.ContextMap; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +class SequentialExecutorTest { + + + private SequentialExecutor.ExceptionHandler exceptionHandler; + private Executor executor; + + @BeforeEach + void setup() { + exceptionHandler = (ignored) -> { }; + executor = new SequentialExecutor(exceptionHandler); + } + + @Test + void tasksAreExecuted() throws InterruptedException { + CountDownLatch latch = new CountDownLatch(2); + // submit two tasks and they should both complete. + executor.execute(() -> latch.countDown()); + executor.execute(() -> latch.countDown()); + latch.await(); + } + + @Test + void firstTaskIsExecutedByCallingThread() { + AtomicReference executorThread = new AtomicReference<>(); + executor.execute(() -> executorThread.set(Thread.currentThread())); + assertNotNull(executorThread.get()); + assertEquals(Thread.currentThread(), executorThread.get()); + } + + @Test + void thrownExceptionsArePropagatedToTheExceptionHandler() { + AtomicReference caught = new AtomicReference<>(); + exceptionHandler = caught::set; + executor = new SequentialExecutor(exceptionHandler); + final RuntimeException ex = new RuntimeException("expected"); + executor.execute(() -> { + throw ex; + }); + assertEquals(ex, caught.get()); + } + + @Test + void queuedTasksAreExecuted() throws InterruptedException { + final CountDownLatch l1 = new CountDownLatch(1); + final CountDownLatch l2 = new CountDownLatch(1); + Thread t = new Thread(() -> + executor.execute(() -> { + try { + l1.countDown(); + l2.await(); + } catch (Exception ex) { + throw new AssertionError("Unexpected failure", ex); + } + })); + t.start(); + + // wait for t1 to be in the execution loop then submit a task that should be queued. + l1.await(); + + // note that the behavior of the initial submitting thread executing queued tasks is not critical to the + // primitive: we could envision another correct implementation where a submitter will execute the task it just + // submitted but if there are additional tasks the work gets shifted to a pooled thread to drain. If we switch + // the model, the test should be adjusted to conform to the desired behavior. + final AtomicReference executingThread = new AtomicReference<>(); + executor.execute(() -> executingThread.set(Thread.currentThread())); + assertNull(executingThread.get()); + + // Now unblock the initial thread and it should also run the second task. + l2.countDown(); + t.join(); + assertEquals(t, executingThread.get()); + } + + @Test + void tasksAreNotRenentrant() { + Queue order = new ArrayDeque<>(); + executor.execute(() -> { + // this should be queued for later. + executor.execute(() -> order.add(2)); + order.add(1); + }); + + assertThat(order, contains(1, 2)); + } + + @Test + void noStackOverflows() throws Exception { + final int maxDepth = 10_000; + // If we substitute `executor` with `(runnable) -> runnable.run()` we get a stack overflow. + final Runnable runnable = new Runnable() { + private final AtomicInteger depth = new AtomicInteger(); + @Override + public void run() { + if (depth.incrementAndGet() < maxDepth) { + executor.execute(this); + } + } + }; + // kick it off. We don't expect any stack-overflows from `SequentialExecutor` which should + // always queue the tasks therefore trading stack space for heap space. + executor.execute(runnable); + } + + @Test + void manyThreadsCanSubmitTasksConcurrently() throws InterruptedException { + final int threadCount = 100; + CountDownLatch completed = new CountDownLatch(threadCount); + CountDownLatch ready = new CountDownLatch(threadCount); + CountDownLatch barrier = new CountDownLatch(1); + + for (int i = 0; i < threadCount; i++) { + Thread t = new Thread(() -> { + try { + ready.countDown(); + barrier.await(); + executor.execute(() -> completed.countDown()); + } catch (Exception ex) { + throw new AssertionError("unexpected error", ex); + } + }); + t.start(); + } + // wait for all the threads to have started + ready.await(); + // release all the threads to submit their work to the executor. + barrier.countDown(); + // all tasks should have completed. Note that all thread are racing with each other to + // submit work so the order of work execution isn't important. + completed.await(); + } + + @Test + void preservesAsyncContext() throws InterruptedException { + final CountDownLatch l1 = new CountDownLatch(1); + final CountDownLatch l2 = new CountDownLatch(1); + // setup a thread to enter the executor and start executing so we can submit another + // task from the test runner thread that shouldn't have the same AsyncContext. + Thread t = new Thread(() -> + executor.execute(() -> { + try { + l1.countDown(); + l2.await(); + } catch (Exception ex) { + throw new AssertionError("Unexpected failure", ex); + } + })); + t.start(); + l1.await(); + + final AtomicReference observedContextValue = new AtomicReference<>(); + final ContextMap.Key key = ContextMap.Key.newKey("testkey", Object.class); + final Object value = new Object(); + + AsyncContext.put(key, value); + executor.execute(() -> observedContextValue.set(AsyncContext.context().get(key))); + assertNull(observedContextValue.get()); + + // Now unblock the initial thread and it should also run the second task. + l2.countDown(); + t.join(); + assertEquals(value, observedContextValue.get()); + } +}