Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Publisher.replay
Browse files Browse the repository at this point in the history
Motivation:
Publisher.replay provides the ability to keep state that is
preserved for multiple subscribers and across resubscribes.
Scottmitch committed Sep 19, 2023
1 parent 3126482 commit 8d3d602
Showing 10 changed files with 1,380 additions and 112 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -56,13 +56,17 @@
import static io.servicetalk.concurrent.api.EmptyPublisher.emptyPublisher;
import static io.servicetalk.concurrent.api.Executors.global;
import static io.servicetalk.concurrent.api.FilterPublisher.newDistinctSupplier;
import static io.servicetalk.concurrent.api.MulticastPublisher.DEFAULT_MULTICAST_QUEUE_LIMIT;
import static io.servicetalk.concurrent.api.MulticastPublisher.DEFAULT_MULTICAST_TERM_RESUB;
import static io.servicetalk.concurrent.api.MulticastPublisher.newMulticastPublisher;
import static io.servicetalk.concurrent.api.NeverPublisher.neverPublisher;
import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnCancelSupplier;
import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnCompleteSupplier;
import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnErrorSupplier;
import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnNextSupplier;
import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnRequestSupplier;
import static io.servicetalk.concurrent.api.PublisherDoOnUtils.doOnSubscribeSupplier;
import static io.servicetalk.concurrent.api.ReplayPublisher.newReplayPublisher;
import static io.servicetalk.concurrent.internal.SubscriberUtils.deliverErrorFromSource;
import static io.servicetalk.utils.internal.DurationUtils.toNanos;
import static java.util.Objects.requireNonNull;
@@ -2991,7 +2995,7 @@ public final <Key> Publisher<GroupedPublisher<Key, T>> groupToMany(
*/
@Deprecated
public final Publisher<T> multicastToExactly(int expectedSubscribers) {
return multicastToExactly(expectedSubscribers, 64);
return multicastToExactly(expectedSubscribers, DEFAULT_MULTICAST_QUEUE_LIMIT);
}

/**
@@ -3023,7 +3027,7 @@ public final Publisher<T> multicastToExactly(int expectedSubscribers) {
*/
@Deprecated
public final Publisher<T> multicastToExactly(int expectedSubscribers, int queueLimit) {
return new MulticastPublisher<>(this, expectedSubscribers, true, true, queueLimit, t -> completed());
return newMulticastPublisher(this, expectedSubscribers, true, true, queueLimit, t -> completed());
}

/**
@@ -3082,7 +3086,7 @@ public final Publisher<T> multicast(int minSubscribers) {
* @see <a href="https://reactivex.io/documentation/operators/publish.html">ReactiveX multicast operator</a>
*/
public final Publisher<T> multicast(int minSubscribers, boolean cancelUpstream) {
return multicast(minSubscribers, 64, cancelUpstream);
return multicast(minSubscribers, DEFAULT_MULTICAST_QUEUE_LIMIT, cancelUpstream);
}

/**
@@ -3145,7 +3149,7 @@ public final Publisher<T> multicast(int minSubscribers, int queueLimit) {
* @see <a href="https://reactivex.io/documentation/operators/publish.html">ReactiveX multicast operator</a>
*/
public final Publisher<T> multicast(int minSubscribers, int queueLimit, boolean cancelUpstream) {
return multicast(minSubscribers, queueLimit, cancelUpstream, t -> completed());
return multicast(minSubscribers, queueLimit, cancelUpstream, DEFAULT_MULTICAST_TERM_RESUB);
}

/**
@@ -3224,7 +3228,73 @@ public final Publisher<T> multicast(int minSubscribers, int queueLimit,
*/
public final Publisher<T> multicast(int minSubscribers, int queueLimit, boolean cancelUpstream,
Function<Throwable, Completable> terminalResubscribe) {
return new MulticastPublisher<>(this, minSubscribers, false, cancelUpstream, queueLimit, terminalResubscribe);
return newMulticastPublisher(this, minSubscribers, false, cancelUpstream, queueLimit, terminalResubscribe);
}

/**
* Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned
* {@link Publisher} but also retains {@code history} of the most recently emitted signals from
* {@link Subscriber#onNext(Object)} which are emitted to new downstream {@link Subscriber}s before emitting new
* signals.
* @param history max number of items to retain which can be delivered to new subscribers.
* @return A {@link Publisher} that allows for multiple downstream subscribers and emits the previous
* {@code history} {@link Subscriber#onNext(Object)} signals to each new subscriber.
* @see <a href="https://reactivex.io/documentation/operators/replay.html">ReactiveX replay operator</a>
* @see ReplayStrategies#historyBuilder(int)
* @see #replay(ReplayStrategy)
*/
public final Publisher<T> replay(int history) {
return replay(ReplayStrategies.<T>historyBuilder(history).build());
}

/**
* Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned
* {@link Publisher} but also retains {@code history} of the most recently emitted signals
* from {@link Subscriber#onNext(Object)} which are emitted to new downstream {@link Subscriber}s before emitting
* new signals. Each item is only retained for {@code ttl} duration of time.
* @param history max number of items to retain which can be delivered to new subscribers.
* @param ttl duration each element will be retained before being removed.
* @param executor used to enforce the {@code ttl} argument.
* @return A {@link Publisher} that allows for multiple downstream subscribers and emits the previous
* {@code history} {@link Subscriber#onNext(Object)} signals to each new subscriber.
* @see <a href="https://reactivex.io/documentation/operators/replay.html">ReactiveX replay operator</a>
* @see ReplayStrategies#historyTtlBuilder(int, Duration, io.servicetalk.concurrent.Executor)
* @see #replay(ReplayStrategy)
*/
public final Publisher<T> replay(int history, Duration ttl, io.servicetalk.concurrent.Executor executor) {
return replay(ReplayStrategies.<T>historyTtlBuilder(history, ttl, executor).build());
}

/**
* Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned
* {@link Publisher} but will also retain some history of {@link Subscriber#onNext(Object)} signals
* according to the {@link ReplayAccumulator} {@code accumulatorSupplier}.
* @param accumulatorSupplier supplies a {@link ReplayAccumulator} on each subscribe to upstream that can retain
* history of {@link Subscriber#onNext(Object)} signals to deliver to new downstream subscribers.
* @return A {@link Publisher} that allows for multiple downstream subscribers that can retain
* history of {@link Subscriber#onNext(Object)} signals to deliver to new downstream subscribers.
* @see <a href="https://reactivex.io/documentation/operators/replay.html">ReactiveX replay operator</a>
* @see #replay(ReplayStrategy)
*/
public final Publisher<T> replay(Supplier<ReplayAccumulator<T>> accumulatorSupplier) {
return replay(new ReplayStrategyBuilder<>(accumulatorSupplier).build());
}

/**
* Similar to {@link #multicast(int)} in that multiple downstream {@link Subscriber}s are enabled on the returned
* {@link Publisher} but will also retain some history of {@link Subscriber#onNext(Object)} signals
* according to the {@link ReplayStrategy} {@code replayStrategy}.
* @param replayStrategy a {@link ReplayStrategy} that determines the replay behavior and history retention logic.
* @return A {@link Publisher} that allows for multiple downstream subscribers that can retain
* history of {@link Subscriber#onNext(Object)} signals to deliver to new downstream subscribers.
* @see <a href="https://reactivex.io/documentation/operators/replay.html">ReactiveX replay operator</a>
* @see ReplayStrategyBuilder
* @see ReplayStrategies
*/
public final Publisher<T> replay(ReplayStrategy<T> replayStrategy) {
return newReplayPublisher(this, replayStrategy.accumulatorSupplier(), replayStrategy.minSubscribers(),
replayStrategy.cancelUpstream(), replayStrategy.queueLimitHint(),
replayStrategy.terminalResubscribe());
}

/**
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.PublisherSource.Subscriber;

import java.util.function.Consumer;
import javax.annotation.Nullable;

/**
* Accumulates signals for the {@link Publisher} replay operator.
* @param <T> The type of data to accumulate.
*/
public interface ReplayAccumulator<T> {
/**
* Called on each {@link Subscriber#onNext(Object)} and intended to accumulate the signal so that new
* {@link Subscriber}s will see this value via {@link #deliverAccumulation(Consumer)}.
* <p>
* This method won't be called concurrently, but should return quickly to minimize performance impacts.
* @param t An {@link Subscriber#onNext(Object)} to accumulate.
*/
void accumulate(@Nullable T t);

/**
* Called to deliver the signals from {@link #accumulate(Object)} to new {@code consumer}.
* @param consumer The consumer of the signals previously aggregated via {@link #accumulate(Object)}.
*/
void deliverAccumulation(Consumer<T> consumer);

/**
* Called if the accumulation can be cancelled and any asynchronous resources can be cleaned up (e.g. timers).
*/
default void cancelAccumulation() {
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.internal.TerminalNotification;

import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked;
import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull;
import static io.servicetalk.concurrent.internal.ConcurrentUtils.releaseLock;
import static io.servicetalk.concurrent.internal.ConcurrentUtils.tryAcquireLock;
import static io.servicetalk.concurrent.internal.SubscriberUtils.safeOnComplete;
import static io.servicetalk.concurrent.internal.SubscriberUtils.safeOnError;
import static io.servicetalk.concurrent.internal.TerminalNotification.complete;
import static io.servicetalk.concurrent.internal.TerminalNotification.error;
import static io.servicetalk.utils.internal.ThrowableUtils.addSuppressed;
import static java.util.Objects.requireNonNull;

final class ReplayPublisher<T> extends MulticastPublisher<T> {
@SuppressWarnings("rawtypes")
private static final AtomicLongFieldUpdater<ReplayPublisher.ReplayState> signalQueuedUpdater =
AtomicLongFieldUpdater.newUpdater(ReplayPublisher.ReplayState.class, "signalsQueued");
private final Supplier<ReplayAccumulator<T>> accumulatorSupplier;

private ReplayPublisher(
Publisher<T> original, Supplier<ReplayAccumulator<T>> accumulatorSupplier, int minSubscribers,
boolean cancelUpstream, int maxQueueSize, Function<Throwable, Completable> terminalResubscribe) {
super(original, minSubscribers, false, cancelUpstream, maxQueueSize, terminalResubscribe);
this.accumulatorSupplier = requireNonNull(accumulatorSupplier);
}

static <T> MulticastPublisher<T> newReplayPublisher(
Publisher<T> original, Supplier<ReplayAccumulator<T>> accumulatorSupplier, int minSubscribers,
boolean cancelUpstream, int maxQueueSize, Function<Throwable, Completable> terminalResubscribe) {
ReplayPublisher<T> publisher = new ReplayPublisher<>(original, accumulatorSupplier, minSubscribers,
cancelUpstream, minSubscribers, terminalResubscribe);
publisher.resetState(maxQueueSize, minSubscribers);
return publisher;
}

@Override
void resetState(int maxQueueSize, int minSubscribers) {
state = new ReplayState(maxQueueSize, minSubscribers, accumulatorSupplier.get());
}

private final class ReplayState extends MulticastPublisher<T>.State {
private final ReplayAccumulator<T> accumulator;
/**
* We could check {@link #subscriptionEvents} is empty, but there are events outside of {@link Subscriber}
* signals in this queue that we don't care about in terms of preserving order, so we keep this count instead
* to only queue when necessary.
*/
volatile long signalsQueued;

ReplayState(final int maxQueueSize, final int minSubscribers,
ReplayAccumulator<T> accumulator) {
super(maxQueueSize, minSubscribers);
this.accumulator = requireNonNull(accumulator);
}

@Override
public void onNext(@Nullable final T t) {
// signalsQueued must be 0 or else items maybe delivered out of order. The value will only be increased
// on the Subscriber thread (no concurrency) and decreased on the draining thread. Optimistically check
// the value here and worst case if the queue has been drained of signals and this thread hasn't yet
// observed the value we will queue but still see correct ordering.
if (signalsQueued == 0 && tryAcquireLock(subscriptionLockUpdater, this)) {
try {
// All subscribers must either see this direct onNext signal, or see it through the accumulator.
// Therefore, we accumulate and deliver onNext while locked to avoid either delivering the signal
// twice (accumulator, addSubscriber, and onNext) or not at all (missed due to concurrency).
accumulator.accumulate(t);
super.onNext(t);
} finally {
if (!releaseLock(subscriptionLockUpdater, this)) {
processSubscriptionEvents();
}
}
} else {
queueOnNext(t);
}
}

@Override
public void onError(final Throwable t) {
if (signalsQueued == 0 && tryAcquireLock(subscriptionLockUpdater, this)) {
try {
super.onError(t);
} finally {
if (!releaseLock(subscriptionLockUpdater, this)) {
processSubscriptionEvents();
}
}
} else {
queueTerminal(error(t));
}
}

@Override
public void onComplete() {
if (signalsQueued == 0 && tryAcquireLock(subscriptionLockUpdater, this)) {
try {
super.onComplete();
} finally {
if (!releaseLock(subscriptionLockUpdater, this)) {
processSubscriptionEvents();
}
}
} else {
queueTerminal(complete());
}
}

@Override
void processOnNextEvent(Object wrapped) {
// subscriptionLockUpdater is held
signalQueuedUpdater.decrementAndGet(this);
final T unwrapped = unwrapNullUnchecked(wrapped);
accumulator.accumulate(unwrapped);
super.onNext(unwrapped);
}

@Override
void processTerminal(TerminalNotification terminalNotification) {
// subscriptionLockUpdater is held
signalQueuedUpdater.decrementAndGet(this);
if (terminalNotification.cause() != null) {
super.onError(terminalNotification.cause());
} else {
super.onComplete();
}
}

@Override
boolean processSubscribeEvent(MulticastFixedSubscriber<T> subscriber,
@Nullable TerminalSubscriber<?> terminalSubscriber) {
// subscriptionLockUpdater is held
if (terminalSubscriber == null) {
// Only call the super class if no terminal event. We don't want the super class to terminate
// the subscriber because we need to deliver any accumulated signals, and we also don't want to
// track state in demandQueue because it isn't necessary to manage upstream demand, and we don't want
// to hold a reference to the subscriber unnecessarily.
super.processSubscribeEvent(subscriber, null);
}
Throwable caughtCause = null;
try {
// It's safe to call onNext before onSubscribe bcz the base class expects onSubscribe to be async and
// queues/reorders events to preserve ReactiveStreams semantics.
accumulator.deliverAccumulation(subscriber::onNext);
} catch (Throwable cause) {
caughtCause = cause;
} finally {
if (terminalSubscriber != null) {
if (caughtCause != null) {
if (terminalSubscriber.terminalError != null) {
// Use caughtCause as original otherwise we keep appending to the cached Throwable.
safeOnError(subscriber, addSuppressed(caughtCause, terminalSubscriber.terminalError));
} else {
safeOnError(subscriber, caughtCause);
}
} else if (terminalSubscriber.terminalError != null) {
safeOnError(subscriber, terminalSubscriber.terminalError);
} else {
safeOnComplete(subscriber);
}
} else if (caughtCause != null) {
safeOnError(subscriber, caughtCause);
}
}
// Even if we terminated we always want to continue processing to trigger onSubscriber and allow queued
// signals from above to be processed when demand arrives.
return true;
}

@Override
void upstreamCancelled() {
// subscriptionLockUpdater is held
accumulator.cancelAccumulation();
}

private void queueOnNext(@Nullable T t) {
signalQueuedUpdater.incrementAndGet(this);
subscriptionEvents.add(wrapNull(t));
processSubscriptionEvents();
}

private void queueTerminal(TerminalNotification terminalNotification) {
signalQueuedUpdater.incrementAndGet(this);
subscriptionEvents.add(terminalNotification);
processSubscriptionEvents();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.Cancellable;
import io.servicetalk.concurrent.Executor;

import java.time.Duration;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.Consumer;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.SubscriberApiUtils.unwrapNullUnchecked;
import static io.servicetalk.concurrent.api.SubscriberApiUtils.wrapNull;
import static io.servicetalk.concurrent.internal.EmptySubscriptions.EMPTY_SUBSCRIPTION_NO_THROW;
import static java.util.Objects.requireNonNull;
import static java.util.concurrent.TimeUnit.NANOSECONDS;
import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater;

/**
* Utilities to customize {@link ReplayStrategy}.
*/
public final class ReplayStrategies {
private ReplayStrategies() {
}

/**
* Create a {@link ReplayStrategyBuilder} using the history strategy.
* @param history max number of items to retain which can be delivered to new subscribers.
* @param <T> The type of {@link ReplayStrategyBuilder}.
* @return a {@link ReplayStrategyBuilder} using the history strategy.
*/
public static <T> ReplayStrategyBuilder<T> historyBuilder(int history) {
return new ReplayStrategyBuilder<>(() -> new MostRecentReplayAccumulator<>(history));
}

/**
* Create a {@link ReplayStrategyBuilder} using the history and TTL strategy.
* @param history max number of items to retain which can be delivered to new subscribers.
* @param ttl duration each element will be retained before being removed.
* @param executor used to enforce the {@code ttl} argument.
* @param <T> The type of {@link ReplayStrategyBuilder}.
* @return a {@link ReplayStrategyBuilder} using the history and TTL strategy.
*/
public static <T> ReplayStrategyBuilder<T> historyTtlBuilder(int history, Duration ttl, Executor executor) {
return new ReplayStrategyBuilder<>(() -> new MostRecentTimeLimitedReplayAccumulator<>(history, ttl, executor));
}

private static final class MostRecentReplayAccumulator<T> implements ReplayAccumulator<T> {
private final int maxItems;
private final Deque<Object> list = new ArrayDeque<>();

MostRecentReplayAccumulator(final int maxItems) {
if (maxItems <= 0) {
throw new IllegalArgumentException("maxItems: " + maxItems + "(expected >0)");
}
this.maxItems = maxItems;
}

@Override
public void accumulate(@Nullable final T t) {
if (list.size() >= maxItems) {
list.pop();
}
list.add(wrapNull(t));
}

@Override
public void deliverAccumulation(final Consumer<T> consumer) {
for (Object item : list) {
consumer.accept(unwrapNullUnchecked(item));
}
}
}

private static final class MostRecentTimeLimitedReplayAccumulator<T> implements ReplayAccumulator<T> {
@SuppressWarnings("rawtypes")
private static final AtomicLongFieldUpdater<MostRecentTimeLimitedReplayAccumulator> stateSizeUpdater =
AtomicLongFieldUpdater.newUpdater(MostRecentTimeLimitedReplayAccumulator.class, "stateSize");
@SuppressWarnings("rawtypes")
private static final AtomicReferenceFieldUpdater<MostRecentTimeLimitedReplayAccumulator, Cancellable>
timerCancellableUpdater = newUpdater(MostRecentTimeLimitedReplayAccumulator.class, Cancellable.class,
"timerCancellable");
private final Executor executor;
private final Queue<TimeStampSignal<T>> items;
private final long ttlNanos;
private final int maxItems;
/**
* Provide atomic state for size of {@link #items} and also for visibility between the threads consuming and
* producing. The atomically incrementing "state" ensures that any modifications from the producer thread
* are visible from the consumer thread and we never "miss" a timer schedule event if the queue becomes empty.
*/
private volatile long stateSize;
@Nullable
private volatile Cancellable timerCancellable;

MostRecentTimeLimitedReplayAccumulator(final int maxItems, final Duration ttl, final Executor executor) {
if (ttl.isNegative()) {
throw new IllegalArgumentException("ttl: " + ttl + "(expected non-negative)");
}
if (maxItems <= 0) {
throw new IllegalArgumentException("maxItems: " + maxItems + "(expected >0)");
}
this.executor = requireNonNull(executor);
this.ttlNanos = ttl.toNanos();
this.maxItems = maxItems;
items = new ConcurrentLinkedQueue<>(); // SpMc
}

@Override
public void accumulate(@Nullable final T t) {
// We may exceed max items in the queue but this method isn't invoked concurrently, so we only go over by
// at most 1 item.
items.add(new TimeStampSignal<>(executor.currentTime(NANOSECONDS), t));
for (;;) {
final long currentStateSize = stateSize;
final int currentSize = getSize(currentStateSize);
final int nextState = getState(currentStateSize) + 1;
if (currentSize >= maxItems) {
if (stateSizeUpdater.compareAndSet(this, currentStateSize,
buildStateSize(nextState, currentSize))) {
items.poll();
break;
}
} else if (stateSizeUpdater.compareAndSet(this, currentStateSize,
buildStateSize(nextState, currentSize + 1))) {
if (currentSize == 0) {
schedulerTimer(ttlNanos);
}
break;
}
}
}

@Override
public void deliverAccumulation(final Consumer<T> consumer) {
for (TimeStampSignal<T> timeStampSignal : items) {
consumer.accept(timeStampSignal.signal);
}
}

@Override
public void cancelAccumulation() {
final Cancellable cancellable = timerCancellableUpdater.getAndSet(this, EMPTY_SUBSCRIPTION_NO_THROW);
if (cancellable != null) {
cancellable.cancel();
}
}

private static int getSize(long stateSize) {
return (int) stateSize;
}

private static int getState(long stateSize) {
return (int) (stateSize >>> 32);
}

private static long buildStateSize(int state, int size) {
return (((long) state) << 32) | size;
}

private void schedulerTimer(long nanos) {
for (;;) {
final Cancellable currentCancellable = timerCancellable;
if (currentCancellable == EMPTY_SUBSCRIPTION_NO_THROW) {
break;
} else {
final Cancellable nextCancellable = executor.schedule(this::expireSignals, nanos, NANOSECONDS);
if (timerCancellableUpdater.compareAndSet(this, currentCancellable, nextCancellable)) {
// Current logic only has 1 timer outstanding at any give time so cancellation of
// the current cancellable shouldn't be necessary but do it for completeness.
if (currentCancellable != null) {
currentCancellable.cancel();
}
break;
} else {
nextCancellable.cancel();
}
}
}
}

private void expireSignals() {
final long nanoTime = executor.currentTime(NANOSECONDS);
TimeStampSignal<T> item;
for (;;) {
// read stateSize before peek, so if we poll from the queue we are sure to see the correct
// state relative to items in the queue.
final long currentStateSize = stateSize;
item = items.peek();
if (item == null) {
break;
} else if (nanoTime - item.timeStamp >= ttlNanos) {
final int currentSize = getSize(currentStateSize);
if (stateSizeUpdater.compareAndSet(this, currentStateSize,
buildStateSize(getState(currentStateSize) + 1, currentSize - 1))) {
// When we add: we add to the queue we add first, then CAS sizeState.
// When we remove: we CAS the atomic state first, then poll.
// This avoids removing a non-expired item because if the "add" thread is running faster and
// already polled "item" the CAS will fail, and we will try again on the next loop iteration.
items.poll();
if (currentSize == 1) {
// a new timer task will be scheduled after addition if this is the case. break to avoid
// multiple timer tasks running concurrently.
break;
}
}
} else {
schedulerTimer(ttlNanos - (nanoTime - item.timeStamp));
break; // elements sorted in increasing time, break when first non-expired entry found.
}
}
}
}

private static final class TimeStampSignal<T> {
final long timeStamp;
@Nullable
final T signal;

private TimeStampSignal(final long timeStamp, @Nullable final T signal) {
this.timeStamp = timeStamp;
this.signal = signal;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.PublisherSource.Subscriber;

import java.util.function.Function;
import java.util.function.Supplier;

/**
* Used to customize the strategy for the {@link Publisher} replay operator.
* @param <T> The type of data.
*/
public interface ReplayStrategy<T> {
/**
* Get the minimum number of downstream subscribers before subscribing upstream.
* @return the minimum number of downstream subscribers before subscribing upstream.
*/
int minSubscribers();

/**
* Get a {@link Supplier} that provides the {@link ReplayAccumulator} on each upstream subscribe.
* @return a {@link Supplier} that provides the {@link ReplayAccumulator} on each upstream subscribe.
*/
Supplier<ReplayAccumulator<T>> accumulatorSupplier();

/**
* Determine if all the downstream subscribers cancel, should upstream be cancelled.
* @return {@code true} if all the downstream subscribers cancel, should upstream be cancelled. {@code false}
* will not cancel upstream if all downstream subscribers cancel.
*/
boolean cancelUpstream();

/**
* Get a hint to limit the number of elements which will be queued for each {@link Subscriber} in order to
* compensate for unequal demand and late subscribers.
* @return a hint to limit the number of elements which will be queued for each {@link Subscriber} in order to
* compensate for unequal demand and late subscribers.
*/
int queueLimitHint();

/**
* Get a {@link Function} that is invoked when a terminal signal arrives from upstream and determines when state
* is reset to allow for upstream resubscribe.
* @return A {@link Function} that is invoked when a terminal signal arrives from upstream, and
* returns a {@link Completable} whose termination resets the state of the returned {@link Publisher} and allows
* for downstream resubscribing. The argument to this function is as follows:
* <ul>
* <li>{@code null} if upstream terminates with {@link Subscriber#onComplete()}</li>
* <li>otherwise the {@link Throwable} from {@link Subscriber#onError(Throwable)}</li>
* </ul>
*/
Function<Throwable, Completable> terminalResubscribe();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.PublisherSource.Subscriber;

import java.util.function.Function;
import java.util.function.Supplier;

import static io.servicetalk.concurrent.api.Completable.never;
import static io.servicetalk.concurrent.api.MulticastPublisher.DEFAULT_MULTICAST_QUEUE_LIMIT;
import static java.util.Objects.requireNonNull;

/**
* A builder of {@link ReplayStrategy}.
* @param <T> The type of data for {@link ReplayStrategy}.
*/
public final class ReplayStrategyBuilder<T> {
private int minSubscribers = 1;
private final Supplier<ReplayAccumulator<T>> accumulatorSupplier;
private boolean cancelUpstream;
private int queueLimitHint = DEFAULT_MULTICAST_QUEUE_LIMIT;
private Function<Throwable, Completable> terminalResubscribe = t -> never();

/**
* Create a new instance.
* @param accumulatorSupplier provides the {@link ReplayAccumulator} to use on each subscribe to upstream.
*/
public ReplayStrategyBuilder(Supplier<ReplayAccumulator<T>> accumulatorSupplier) {
this.accumulatorSupplier = requireNonNull(accumulatorSupplier);
}

/**
* Set the minimum number of downstream subscribers before subscribing upstream.
* @param minSubscribers the minimum number of downstream subscribers before subscribing upstream.
* @return {@code this}.
*/
public ReplayStrategyBuilder<T> minSubscribers(int minSubscribers) {
if (minSubscribers <= 0) {
throw new IllegalArgumentException("minSubscribers: " + minSubscribers + " (expected >0)");
}
this.minSubscribers = minSubscribers;
return this;
}

/**
* Determine if all the downstream subscribers cancel, should upstream be cancelled.
* @param cancelUpstream {@code true} if all the downstream subscribers cancel, should upstream be cancelled.
* {@code false} will not cancel upstream if all downstream subscribers cancel.
* @return {@code this}.
*/
public ReplayStrategyBuilder<T> cancelUpstream(boolean cancelUpstream) {
this.cancelUpstream = cancelUpstream;
return this;
}

/**
* Set a hint to limit the number of elements which will be queued for each {@link Subscriber} in order to
* compensate for unequal demand and late subscribers.
* @param queueLimitHint a hint to limit the number of elements which will be queued for each {@link Subscriber} in
* order to compensate for unequal demand and late subscribers.
* @return {@code this}.
*/
public ReplayStrategyBuilder<T> queueLimitHint(int queueLimitHint) {
if (queueLimitHint < 1) {
throw new IllegalArgumentException("maxQueueSize: " + queueLimitHint + " (expected >1)");
}
this.queueLimitHint = queueLimitHint;
return this;
}

/**
* Set a {@link Function} that is invoked when a terminal signal arrives from upstream and determines when state
* is reset to allow for upstream resubscribe.
* @param terminalResubscribe A {@link Function} that is invoked when a terminal signal arrives from upstream, and
* returns a {@link Completable} whose termination resets the state of the returned {@link Publisher} and allows
* for downstream resubscribing. The argument to this function is as follows:
* <ul>
* <li>{@code null} if upstream terminates with {@link Subscriber#onComplete()}</li>
* <li>otherwise the {@link Throwable} from {@link Subscriber#onError(Throwable)}</li>
* </ul>
* @return {@code this}.
*/
public ReplayStrategyBuilder<T> terminalResubscribe(
Function<Throwable, Completable> terminalResubscribe) {
this.terminalResubscribe = requireNonNull(terminalResubscribe);
return this;
}

/**
* Build the {@link ReplayStrategy}.
* @return the {@link ReplayStrategy}.
*/
public ReplayStrategy<T> build() {
return new DefaultReplayStrategy<>(minSubscribers, accumulatorSupplier, cancelUpstream, queueLimitHint,
terminalResubscribe);
}

private static final class DefaultReplayStrategy<T> implements ReplayStrategy<T> {
private final int minSubscribers;
private final Supplier<ReplayAccumulator<T>> accumulatorSupplier;
private final boolean cancelUpstream;
private final int queueLimitHint;
private final Function<Throwable, Completable> terminalResubscribe;

private DefaultReplayStrategy(
final int minSubscribers, final Supplier<ReplayAccumulator<T>> accumulatorSupplier,
final boolean cancelUpstream, final int queueLimitHint,
final Function<Throwable, Completable> terminalResubscribe) {
this.minSubscribers = minSubscribers;
this.accumulatorSupplier = accumulatorSupplier;
this.cancelUpstream = cancelUpstream;
this.queueLimitHint = queueLimitHint;
this.terminalResubscribe = terminalResubscribe;
}

@Override
public int minSubscribers() {
return minSubscribers;
}

@Override
public Supplier<ReplayAccumulator<T>> accumulatorSupplier() {
return accumulatorSupplier;
}

@Override
public boolean cancelUpstream() {
return cancelUpstream;
}

@Override
public int queueLimitHint() {
return queueLimitHint;
}

@Override
public Function<Throwable, Completable> terminalResubscribe() {
return terminalResubscribe;
}
}
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,309 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.api;

import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.CsvSource;
import org.junit.jupiter.params.provider.ValueSource;

import java.time.Duration;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.Future;
import java.util.function.Consumer;
import java.util.function.Function;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.SourceAdapters.toSource;
import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static java.time.Duration.ofMillis;
import static java.util.concurrent.TimeUnit.MILLISECONDS;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;

final class ReplayPublisherTest extends MulticastPublisherTest {
private final TestPublisherSubscriber<Integer> subscriber4 = new TestPublisherSubscriber<>();
private final TestExecutor executor = new TestExecutor();

@AfterEach
void tearDown() throws Exception {
executor.closeAsync().toFuture().get();
}

@Override
<T> Publisher<T> applyOperator(Publisher<T> source, int minSubscribers) {
return source.replay(new ReplayStrategyBuilder<T>(EmptyReplayAccumulator::emptyAccumulator)
.minSubscribers(minSubscribers).build());
}

@Override
<T> Publisher<T> applyOperator(Publisher<T> source, int minSubscribers, boolean cancelUpstream) {
return source.replay(new ReplayStrategyBuilder<T>(EmptyReplayAccumulator::emptyAccumulator)
.cancelUpstream(cancelUpstream)
.minSubscribers(minSubscribers).build());
}

@Override
<T> Publisher<T> applyOperator(Publisher<T> source, int minSubscribers, int queueLimit,
Function<Throwable, Completable> terminalResubscribe) {
return source.replay(new ReplayStrategyBuilder<T>(EmptyReplayAccumulator::emptyAccumulator)
.queueLimitHint(queueLimit)
.terminalResubscribe(terminalResubscribe)
.minSubscribers(minSubscribers).build());
}

@Override
<T> Publisher<T> applyOperator(Publisher<T> source, int minSubscribers, int queueLimit) {
return source.replay(new ReplayStrategyBuilder<T>(EmptyReplayAccumulator::emptyAccumulator)
.queueLimitHint(queueLimit)
.minSubscribers(minSubscribers).build());
}

@Override
<T> Publisher<T> applyOperator(Publisher<T> source, int minSubscribers, int queueLimit, boolean cancelUpstream) {
return source.replay(new ReplayStrategyBuilder<T>(EmptyReplayAccumulator::emptyAccumulator)
.queueLimitHint(queueLimit)
.cancelUpstream(cancelUpstream)
.minSubscribers(minSubscribers).build());
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void twoSubscribersHistory(boolean onError) {
Publisher<Integer> publisher = source.replay(2);
toSource(publisher).subscribe(subscriber1);
subscriber1.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));
source.onNext(1, 2, null);
assertThat(subscriber1.takeOnNext(3), contains(1, 2, null));

toSource(publisher).subscribe(subscriber2);
subscriber2.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));

assertThat(subscriber2.takeOnNext(2), contains(2, null));

source.onNext(4);
assertThat(subscriber1.takeOnNext(), is(4));
assertThat(subscriber2.takeOnNext(), is(4));

twoSubscribersTerminate(onError);
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void subscribeAfterTerminalDeliversHistory(boolean onError) {
Publisher<Integer> publisher = source.replay(2);
toSource(publisher).subscribe(subscriber1);
subscriber1.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));
source.onNext(1, 2, 3);
assertThat(subscriber1.takeOnNext(3), contains(1, 2, 3));
if (onError) {
source.onError(DELIBERATE_EXCEPTION);
assertThat(subscriber1.awaitOnError(), is(DELIBERATE_EXCEPTION));
} else {
source.onComplete();
subscriber1.awaitOnComplete();
}

toSource(publisher).subscribe(subscriber2);
subscriber2.awaitSubscription().request(4);
assertThat(subscriber2.takeOnNext(2), contains(2, 3));
if (onError) {
assertThat(subscriber2.awaitOnError(), is(DELIBERATE_EXCEPTION));
} else {
subscriber2.awaitOnComplete();
}
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void threeSubscribersSum(boolean onError) {
Publisher<Integer> publisher = source.replay(SumReplayAccumulator::new);
toSource(publisher).subscribe(subscriber1);
subscriber1.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));
source.onNext(1, 2, 3);
assertThat(subscriber1.takeOnNext(3), contains(1, 2, 3));

toSource(publisher).subscribe(subscriber2);
subscriber2.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));

assertThat(subscriber2.takeOnNext(), equalTo(6));

source.onNext(4);
assertThat(subscriber1.takeOnNext(), is(4));
assertThat(subscriber2.takeOnNext(), is(4));

toSource(publisher).subscribe(subscriber3);
subscriber3.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));
assertThat(subscriber3.takeOnNext(), equalTo(10));

subscriber1.awaitSubscription().request(1);
assertThat(subscription.requested(), is(5L));
source.onNext(5);

assertThat(subscriber1.takeOnNext(), is(5));
assertThat(subscriber2.takeOnNext(), is(5));
assertThat(subscriber3.takeOnNext(), is(5));

threeSubscribersTerminate(onError);
}

@ParameterizedTest
@ValueSource(booleans = {true, false})
void threeSubscribersTTL(boolean onError) {
final Duration ttl = ofMillis(2);
Publisher<Integer> publisher = source.replay(2, ttl, executor);
toSource(publisher).subscribe(subscriber1);
subscriber1.awaitSubscription().request(4);
assertThat(subscription.requested(), is(4L));
source.onNext(1, 2);
executor.advanceTimeBy(1, MILLISECONDS);
source.onNext((Integer) null);
assertThat(subscriber1.takeOnNext(3), contains(1, 2, null));

toSource(publisher).subscribe(subscriber2);
subscriber2.awaitSubscription().request(4);
assertThat(subscriber2.takeOnNext(2), contains(2, null));

executor.advanceTimeBy(1, MILLISECONDS);
toSource(publisher).subscribe(subscriber3);
subscriber3.awaitSubscription().request(4);
assertThat(subscriber3.takeOnNext(), equalTo(null));

source.onNext(4);
assertThat(subscriber1.takeOnNext(), equalTo(4));
assertThat(subscriber2.takeOnNext(), equalTo(4));
assertThat(subscriber3.takeOnNext(), equalTo(4));

subscriber1.awaitSubscription().request(10);
subscriber2.awaitSubscription().request(10);
subscriber3.awaitSubscription().request(10);
executor.advanceTimeBy(ttl.toMillis(), MILLISECONDS);
toSource(publisher).subscribe(subscriber4);
subscriber4.awaitSubscription().request(4);
assertThat(subscriber4.pollOnNext(10, MILLISECONDS), nullValue());

threeSubscribersTerminate(onError);
}

@ParameterizedTest(name = "{displayName} [{index}] expectedSubscribers={0} expectedSum={1}")
@CsvSource(value = {"500,500", "50,50", "50,500", "500,50"})
void concurrentSubscribes(final int expectedSubscribers, final long expectedSum) throws Exception {
Publisher<Integer> replay = source.replay(SumReplayAccumulator::new);
CyclicBarrier startBarrier = new CyclicBarrier(expectedSubscribers + 1);
Completable[] completables = new Completable[expectedSubscribers];
@SuppressWarnings("unchecked")
TestPublisherSubscriber<Integer>[] subscribers = (TestPublisherSubscriber<Integer>[])
new TestPublisherSubscriber[expectedSubscribers];
Executor executor = Executors.newCachedThreadExecutor();
try {
for (int i = 0; i < subscribers.length; ++i) {
final TestPublisherSubscriber<Integer> currSubscriber = new TestPublisherSubscriber<>();
subscribers[i] = currSubscriber;
completables[i] = executor.submit(() -> {
try {
startBarrier.await();
toSource(replay).subscribe(currSubscriber);
currSubscriber.awaitSubscription().request(expectedSum);
} catch (Exception e) {
throw new RuntimeException(e);
}
});
}

Future<Void> future = Completable.mergeAll(completables.length, completables).toFuture();
startBarrier.await();
for (int i = 0; i < expectedSum; ++i) {
subscription.awaitRequestN(i + 1);
source.onNext(1);
}

future.get();
source.onComplete(); // deliver terminal after all requests have been delivered.

for (final TestPublisherSubscriber<Integer> currSubscriber : subscribers) {
int numOnNext = 0;
long currSum = 0;
while (currSum < expectedSum) {
Integer next = currSubscriber.takeOnNext();
++numOnNext;
if (next != null) {
currSum += next;
}
}
try {
assertThat(currSum, equalTo(expectedSum));
currSubscriber.awaitOnComplete();
} catch (Throwable cause) {
throw new AssertionError("failure numOnNext=" + numOnNext, cause);
}
}

subscription.awaitRequestN(expectedSum);
assertThat(subscription.isCancelled(), is(false));
} finally {
executor.closeAsync().toFuture().get();
}
}

private static final class EmptyReplayAccumulator<T> implements ReplayAccumulator<T> {
static final ReplayAccumulator<?> INSTANCE = new EmptyReplayAccumulator<>();

private EmptyReplayAccumulator() {
}

@SuppressWarnings("unchecked")
static <T> ReplayAccumulator<T> emptyAccumulator() {
return (ReplayAccumulator<T>) INSTANCE;
}

@Override
public void accumulate(@Nullable final T t) {
}

@Override
public void deliverAccumulation(final Consumer<T> consumer) {
}
}

private static final class SumReplayAccumulator implements ReplayAccumulator<Integer> {
private int sum;

@Override
public void accumulate(@Nullable final Integer integer) {
if (integer != null) {
sum += integer;
}
}

@Override
public void deliverAccumulation(final Consumer<Integer> consumer) {
if (sum != 0) {
consumer.accept(sum);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Copyright © 2023 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.concurrent.reactivestreams.tck;

import io.servicetalk.concurrent.api.Publisher;

import org.testng.annotations.Test;

@Test
public class PublisherReplayTckTest extends AbstractPublisherOperatorTckTest<Integer> {
@Override
protected Publisher<Integer> composePublisher(Publisher<Integer> publisher, int elements) {
return publisher.replay(1);
}
}

0 comments on commit 8d3d602

Please sign in to comment.