diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java index fe51359b25..3847adf03b 100644 --- a/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java +++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/OrderedStreamAsync.java @@ -24,16 +24,12 @@ import org.apache.ratis.datastream.impl.DataStreamRequestByteBuffer; import org.apache.ratis.datastream.impl.DataStreamRequestFilePositionCount; import org.apache.ratis.io.FilePositionCount; -import org.apache.ratis.io.StandardWriteOption; import org.apache.ratis.protocol.DataStreamReply; import org.apache.ratis.protocol.DataStreamRequest; import org.apache.ratis.protocol.DataStreamRequestHeader; -import org.apache.ratis.protocol.exceptions.TimeoutIOException; import org.apache.ratis.util.IOUtils; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.SlidingWindow; -import org.apache.ratis.util.TimeDuration; -import org.apache.ratis.util.TimeoutExecutor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -105,15 +101,10 @@ public String toString() { private final DataStreamClientRpc dataStreamClientRpc; private final Semaphore requestSemaphore; - private final TimeDuration requestTimeout; - private final TimeDuration closeTimeout; - private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance(); OrderedStreamAsync(DataStreamClientRpc dataStreamClientRpc, RaftProperties properties){ this.dataStreamClientRpc = dataStreamClientRpc; this.requestSemaphore = new Semaphore(RaftClientConfigKeys.DataStream.outstandingRequestsMax(properties)); - this.requestTimeout = RaftClientConfigKeys.DataStream.requestTimeout(properties); - this.closeTimeout = requestTimeout.multiply(2); } CompletableFuture sendRequest(DataStreamRequestHeader header, Object data, @@ -149,9 +140,6 @@ private void sendRequestToNetwork(DataStreamWindowRequest request, request.getDataStreamRequest()); long seqNum = request.getSeqNum(); - final boolean isClose = request.getDataStreamRequest().getWriteOptionList().contains(StandardWriteOption.CLOSE); - scheduleWithTimeout(request, isClose? closeTimeout: requestTimeout); - requestFuture.thenApply(reply -> { slidingWindow.receiveReply( seqNum, reply, r -> sendRequestToNetwork(r, slidingWindow)); @@ -166,13 +154,4 @@ private void sendRequestToNetwork(DataStreamWindowRequest request, return null; }); } - - private void scheduleWithTimeout(DataStreamWindowRequest request, TimeDuration timeout) { - scheduler.onTimeout(timeout, () -> { - if (!request.getReplyFuture().isDone()) { - request.getReplyFuture().completeExceptionally( - new TimeoutIOException("Timeout " + timeout + ": Failed to send " + request)); - } - }, LOG, () -> "Failed to completeExceptionally for " + request); - } } diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java new file mode 100644 index 0000000000..fc97b6fe34 --- /dev/null +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientReplies.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.ratis.netty.client; + +import org.apache.ratis.proto.RaftProtos.DataStreamPacketHeaderProto.Type; +import org.apache.ratis.protocol.ClientInvocationId; +import org.apache.ratis.protocol.DataStreamPacket; +import org.apache.ratis.protocol.DataStreamReply; +import org.apache.ratis.thirdparty.io.netty.util.concurrent.ScheduledFuture; +import org.apache.ratis.util.MemoizedSupplier; +import org.apache.ratis.util.Preconditions; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicReference; + +public class NettyClientReplies { + public static final Logger LOG = LoggerFactory.getLogger(NettyClientReplies.class); + + private final ConcurrentMap replies = new ConcurrentHashMap<>(); + + ReplyMap getReplyMap(ClientInvocationId clientInvocationId) { + final MemoizedSupplier q = MemoizedSupplier.valueOf(() -> new ReplyMap(clientInvocationId)); + return replies.computeIfAbsent(clientInvocationId, key -> q.get()); + } + + class ReplyMap { + private final ClientInvocationId clientInvocationId; + private final Map map = new ConcurrentHashMap<>(); + + ReplyMap(ClientInvocationId clientInvocationId) { + this.clientInvocationId = clientInvocationId; + } + + ReplyEntry submitRequest(RequestEntry requestEntry, boolean isClose, CompletableFuture f) { + LOG.debug("put {} to the map for {}", requestEntry, clientInvocationId); + final MemoizedSupplier replySupplier = MemoizedSupplier.valueOf(() -> new ReplyEntry(isClose, f)); + return map.computeIfAbsent(requestEntry, r -> replySupplier.get()); + } + + void receiveReply(DataStreamReply reply) { + final RequestEntry requestEntry = new RequestEntry(reply); + final ReplyEntry replyEntry = map.remove(requestEntry); + LOG.debug("remove: {}; replyEntry: {}; reply: {}", requestEntry, replyEntry, reply); + if (replyEntry == null) { + LOG.debug("Request not found: {}", this); + return; + } + replyEntry.complete(reply); + if (!reply.isSuccess()) { + failAll("a request failed with " + reply); + } else if (replyEntry.isClosed()) { // stream closed clean up reply map + removeThisMap(); + } + } + + private void removeThisMap() { + final ReplyMap removed = replies.remove(clientInvocationId); + Preconditions.assertSame(removed, this, "removed"); + } + + void completeExceptionally(Throwable e) { + removeThisMap(); + for (ReplyEntry entry : map.values()) { + entry.completeExceptionally(e); + } + map.clear(); + } + + private void failAll(String message) { + completeExceptionally(new IllegalStateException(this + ": " + message)); + } + + void fail(RequestEntry requestEntry) { + map.remove(requestEntry); + failAll(requestEntry + " failed "); + } + + @Override + public String toString() { + final StringBuilder builder = new StringBuilder(); + for (RequestEntry requestEntry : map.keySet()) { + builder.append(requestEntry).append(", "); + } + return builder.toString(); + } + } + + static class RequestEntry { + private final long streamOffset; + private final Type type; + + RequestEntry(DataStreamPacket packet) { + this.streamOffset = packet.getStreamOffset(); + this.type = packet.getType(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + final RequestEntry that = (RequestEntry) o; + return streamOffset == that.streamOffset + && type == that.type; + } + + @Override + public int hashCode() { + return Objects.hash(type, streamOffset); + } + + @Override + public String toString() { + return "Request{" + + "streamOffset=" + streamOffset + + ", type=" + type + + '}'; + } + } + + static class ReplyEntry { + private final boolean isClosed; + private final CompletableFuture replyFuture; + private final AtomicReference> timeoutFuture = new AtomicReference<>(); + + ReplyEntry(boolean isClosed, CompletableFuture replyFuture) { + this.isClosed = isClosed; + this.replyFuture = replyFuture; + } + + boolean isClosed() { + return isClosed; + } + + void complete(DataStreamReply reply) { + cancelTimeoutFuture(); + replyFuture.complete(reply); + } + + void completeExceptionally(Throwable t) { + cancelTimeoutFuture(); + replyFuture.completeExceptionally(t); + } + + private void cancelTimeoutFuture() { + Optional.ofNullable(timeoutFuture.get()).ifPresent(f -> f.cancel(false)); + } + + void setTimeoutFuture(ScheduledFuture timeoutFuture) { + this.timeoutFuture.compareAndSet(null, timeoutFuture); + } + } +} diff --git a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java index 51326d13ea..f815bcffe3 100644 --- a/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java +++ b/ratis-netty/src/main/java/org/apache/ratis/netty/client/NettyClientStreamRpc.java @@ -33,6 +33,7 @@ import org.apache.ratis.protocol.DataStreamRequest; import org.apache.ratis.protocol.RaftPeer; import org.apache.ratis.protocol.exceptions.AlreadyClosedException; +import org.apache.ratis.protocol.exceptions.TimeoutIOException; import org.apache.ratis.security.TlsConf; import org.apache.ratis.thirdparty.io.netty.bootstrap.Bootstrap; import org.apache.ratis.thirdparty.io.netty.buffer.ByteBuf; @@ -51,26 +52,21 @@ import org.apache.ratis.thirdparty.io.netty.handler.codec.ByteToMessageDecoder; import org.apache.ratis.thirdparty.io.netty.handler.codec.MessageToMessageEncoder; import org.apache.ratis.thirdparty.io.netty.handler.ssl.SslContext; +import org.apache.ratis.thirdparty.io.netty.util.concurrent.ScheduledFuture; import org.apache.ratis.util.JavaUtils; import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.NetUtils; import org.apache.ratis.util.SizeInBytes; import org.apache.ratis.util.TimeDuration; -import org.apache.ratis.util.TimeoutExecutor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import java.io.IOException; import java.net.InetSocketAddress; import java.nio.ByteBuffer; -import java.util.Iterator; import java.util.List; import java.util.Optional; -import java.util.Queue; import java.util.concurrent.CompletableFuture; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentLinkedQueue; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; @@ -115,39 +111,6 @@ void shutdownGracefully() { } } - static class ReplyQueue implements Iterable> { - static final ReplyQueue EMPTY = new ReplyQueue(); - - private final Queue> queue = new ConcurrentLinkedQueue<>(); - private int emptyId; - - /** @return an empty ID if the queue is empty; otherwise, the queue is non-empty, return null. */ - synchronized Integer getEmptyId() { - return queue.isEmpty()? emptyId: null; - } - - synchronized boolean offer(CompletableFuture f) { - if (queue.offer(f)) { - emptyId++; - return true; - } - return false; - } - - CompletableFuture poll() { - return queue.poll(); - } - - int size() { - return queue.size(); - } - - @Override - public Iterator> iterator() { - return queue.iterator(); - } - } - static class Connection { static final TimeDuration RECONNECT = TimeDuration.valueOf(100, TimeUnit.MILLISECONDS); @@ -275,17 +238,19 @@ synchronized boolean shouldFlush(boolean force, int countMin, SizeInBytes bytesM private final String name; private final Connection connection; + private final NettyClientReplies replies = new NettyClientReplies(); + private final TimeDuration requestTimeout; + private final TimeDuration closeTimeout; + private final int flushRequestCountMin; private final SizeInBytes flushRequestBytesMin; private final OutstandingRequests outstandingRequests = new OutstandingRequests(); - private final ConcurrentMap replies = new ConcurrentHashMap<>(); - private final TimeDuration replyQueueGracePeriod; - private final TimeoutExecutor timeoutScheduler = TimeoutExecutor.getInstance(); - public NettyClientStreamRpc(RaftPeer server, TlsConf tlsConf, RaftProperties properties) { this.name = JavaUtils.getClassSimpleName(getClass()) + "->" + server; - this.replyQueueGracePeriod = NettyConfigKeys.DataStream.Client.replyQueueGracePeriod(properties); + this.requestTimeout = RaftClientConfigKeys.DataStream.requestTimeout(properties); + this.closeTimeout = requestTimeout.multiply(2); + this.flushRequestCountMin = RaftClientConfigKeys.DataStream.flushRequestCountMin(properties); this.flushRequestBytesMin = RaftClientConfigKeys.DataStream.flushRequestBytesMin(properties); @@ -299,8 +264,6 @@ public NettyClientStreamRpc(RaftPeer server, TlsConf tlsConf, RaftProperties pro private ChannelInboundHandler getClientHandler(){ return new ChannelInboundHandlerAdapter(){ - private ClientInvocationId clientInvocationId; - @Override public void channelRead(ChannelHandlerContext ctx, Object msg) { if (!(msg instanceof DataStreamReply)) { @@ -309,29 +272,19 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { } final DataStreamReply reply = (DataStreamReply) msg; LOG.debug("{}: read {}", this, reply); - clientInvocationId = ClientInvocationId.valueOf(reply.getClientId(), reply.getStreamId()); - final ReplyQueue queue = reply.isSuccess() ? replies.get(clientInvocationId) : - replies.remove(clientInvocationId); - if (queue != null) { - final CompletableFuture f = queue.poll(); - if (f != null) { - f.complete(reply); - - if (!reply.isSuccess() && queue.size() > 0) { - final IllegalStateException e = new IllegalStateException( - this + ": an earlier request failed with " + reply); - queue.forEach(future -> future.completeExceptionally(e)); - } + final ClientInvocationId clientInvocationId = ClientInvocationId.valueOf( + reply.getClientId(), reply.getStreamId()); + final NettyClientReplies.ReplyMap replyMap = replies.getReplyMap(clientInvocationId); + if (replyMap == null) { + LOG.error("{}: {} replyMap not found for reply: {}", this, clientInvocationId, reply); + return; + } - final Integer emptyId = queue.getEmptyId(); - if (emptyId != null) { - timeoutScheduler.onTimeout(replyQueueGracePeriod, - // remove the queue if the same queue has been empty for the entire grace period. - () -> replies.computeIfPresent(clientInvocationId, - (key, q) -> q == queue && emptyId.equals(q.getEmptyId())? null: q), - LOG, () -> "Timeout check failed, clientInvocationId=" + clientInvocationId); - } - } + try { + replyMap.receiveReply(reply); + } catch (Throwable cause) { + LOG.warn(name + ": channelRead error:", cause); + replyMap.completeExceptionally(cause); } } @@ -339,10 +292,6 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { LOG.warn(name + ": exceptionCaught", cause); - Optional.ofNullable(clientInvocationId) - .map(replies::remove) - .orElse(ReplyQueue.EMPTY) - .forEach(f -> f.completeExceptionally(cause)); ctx.close(); } @@ -417,24 +366,45 @@ protected void decode(ChannelHandlerContext context, ByteBuf buf, List o public CompletableFuture streamAsync(DataStreamRequest request) { final CompletableFuture f = new CompletableFuture<>(); ClientInvocationId clientInvocationId = ClientInvocationId.valueOf(request.getClientId(), request.getStreamId()); - final ReplyQueue q = replies.computeIfAbsent(clientInvocationId, key -> new ReplyQueue()); - if (!q.offer(f)) { - f.completeExceptionally(new IllegalStateException(this + ": Failed to offer a future for " + request)); - return f; - } - final Channel channel = connection.getChannelUninterruptibly(); - if (channel == null) { - f.completeExceptionally(new AlreadyClosedException(this + ": Failed to send " + request)); - return f; + final boolean isClose = request.getWriteOptionList().contains(StandardWriteOption.CLOSE); + + final NettyClientReplies.ReplyMap replyMap = replies.getReplyMap(clientInvocationId); + final ChannelFuture channelFuture; + final Channel channel; + final NettyClientReplies.RequestEntry requestEntry = new NettyClientReplies.RequestEntry(request); + final NettyClientReplies.ReplyEntry replyEntry; + LOG.debug("{}: write begin {}", this, request); + synchronized (replyMap) { + channel = connection.getChannelUninterruptibly(); + if (channel == null) { + f.completeExceptionally(new AlreadyClosedException(this + ": Failed to send " + request)); + return f; + } + replyEntry = replyMap.submitRequest(requestEntry, isClose, f); + final Function writeMethod = outstandingRequests.write(request)? + channel::writeAndFlush: channel::write; + channelFuture = writeMethod.apply(request); } - LOG.debug("{}: write {}", this, request); - final Function writeMethod = outstandingRequests.write(request)? - channel::writeAndFlush: channel::write; - writeMethod.apply(request).addListener(future -> { + channelFuture.addListener(future -> { if (!future.isSuccess()) { - final IOException e = new IOException(this + ": Failed to send " + request, future.cause()); - LOG.error("Channel write failed", e); + final IOException e = new IOException(this + ": Failed to send " + request + " to " + channel.remoteAddress(), + future.cause()); f.completeExceptionally(e); + replyMap.fail(requestEntry); + LOG.error("Channel write failed", e); + } else { + LOG.debug("{}: write after {}", this, request); + + final TimeDuration timeout = isClose ? closeTimeout : requestTimeout; + // if reply success cancel this future + final ScheduledFuture timeoutFuture = channel.eventLoop().schedule(() -> { + if (!f.isDone()) { + f.completeExceptionally(new TimeoutIOException( + "Timeout " + timeout + ": Failed to send " + request + " channel: " + channel)); + replyMap.fail(requestEntry); + } + }, timeout.toLong(timeout.getUnit()), timeout.getUnit()); + replyEntry.setTimeoutFuture(timeoutFuture); } }); return f;