diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedAsync.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedAsync.java index a1aa58681c..e17443d661 100644 --- a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedAsync.java +++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedAsync.java @@ -83,6 +83,10 @@ public void setFirstRequest() { isFirst = true; } + public long getCallId() { + return callId; + } + @Override public long getSeqNum() { return seqNum; @@ -133,7 +137,7 @@ private OrderedAsync(RaftClientImpl client, RaftProperties properties) { } private void resetSlidingWindow(RaftClientRequest request) { - getSlidingWindow(request).resetFirstSeqNum(); + getSlidingWindow(request).resetFirstSeqNum(request.getCallId()); } private SlidingWindow.Client getSlidingWindow(RaftClientRequest request) { diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java index 989c00cbbc..ac91873521 100644 --- a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java +++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java @@ -73,6 +73,10 @@ public long getSeqNum() { return seqNum; } + public long getCallId() { + return -1; + } + @Override public void setReply(DataStreamReply dataStreamReply) { replyFuture.complete(dataStreamReply); diff --git a/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java b/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java index 316604db07..e9608d457c 100644 --- a/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java +++ b/ratis-common/src/main/java/org/apache/ratis/util/SlidingWindow.java @@ -27,6 +27,7 @@ import java.util.List; import java.util.SortedMap; import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentSkipListMap; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; @@ -56,6 +57,7 @@ interface Request { interface ClientSideRequest extends Request { void setFirstRequest(); + long getCallId(); } interface ServerSideRequest extends Request { @@ -228,13 +230,14 @@ class Client, REPLY> { private final RequestMap requests; /** Delayed requests. */ private final DelayedRequests delayedRequests = new DelayedRequests(); + private final ConcurrentHashMap map = new ConcurrentHashMap<>(); /** The seqNum for the next new request. */ private long nextSeqNum = 1; /** The seqNum of the first request. */ - private long firstSeqNum = -1; + private volatile long firstSeqNum = -1; /** Is the first request replied? */ - private boolean firstReplied; + private volatile boolean firstReplied; /** The exception, if there is any. */ private Throwable exception; @@ -300,6 +303,7 @@ private boolean sendOrDelayRequest(REQUEST request, Consumer sendMethod if (firstReplied) { // already received the reply for the first request, submit any request. + map.put(request.getCallId(), getFirstSeqNum()); sendMethod.accept(request); return true; } @@ -309,6 +313,7 @@ private boolean sendOrDelayRequest(REQUEST request, Consumer sendMethod LOG.debug("{}: detect firstSubmitted {} in {}", requests.getName(), request, this); firstSeqNum = seqNum; request.setFirstRequest(); + map.put(request.getCallId(), getFirstSeqNum()); sendMethod.accept(request); return true; } @@ -333,7 +338,9 @@ public synchronized void retry(REQUEST request, Consumer sendMethod) { private void removeRepliedFromHead() { for (final Iterator i = requests.iterator(); i.hasNext(); i.remove()) { final REQUEST r = i.next(); - if (!r.hasReply()) { + if (r.hasReply()) { + map.remove(r.getCallId()); + } else { return; } } @@ -374,10 +381,12 @@ private void trySendDelayed(Consumer sendMethod) { } /** Reset the {@link #firstSeqNum} The stream has an error. */ - public synchronized void resetFirstSeqNum() { - firstSeqNum = -1; - firstReplied = false; - LOG.debug("After resetFirstSeqNum: {}", this); + public synchronized void resetFirstSeqNum(long callId) { + if (callId == -1 || getFirstSeqNum() == map.get(callId)) { + firstSeqNum = -1; + firstReplied = false; + LOG.debug("After resetFirstSeqNum: {}", this); + } } /** Fail all requests starting from the given seqNum. */ @@ -409,6 +418,10 @@ private void alreadyClosed(REQUEST request, Throwable e) { public synchronized boolean isFirst(long seqNum) { return seqNum == (firstSeqNum != -1 ? firstSeqNum : requests.firstSeqNum()); } + + public long getFirstSeqNum() { + return firstSeqNum != -1 ? firstSeqNum : requests.firstSeqNum(); + } } /**