diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java index 3c0f14fb71..db19831955 100644 --- a/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java +++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/ClientProtoUtils.java @@ -131,6 +131,7 @@ static RaftRpcRequestProto.Builder toRaftRpcRequestProtoBuilder(RaftClientReques return b.setCallId(request.getCallId()) .setToLeader(request.isToLeader()) + .addAllRepliedCallIds(request.getRepliedCallIds()) .setTimeoutMs(request.getTimeoutMs()); } @@ -192,6 +193,7 @@ static RaftClientRequest toRaftClientRequest(RaftClientRequestProto p) { .setCallId(request.getCallId()) .setMessage(toMessage(p.getMessage())) .setType(type) + .setRepliedCallIds(request.getRepliedCallIdsList()) .setRoutingTable(getRoutingTable(request)) .setTimeoutMs(request.getTimeoutMs()) .build(); diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java index 9847beed7c..ec16763c2c 100644 --- a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java +++ b/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientImpl.java @@ -59,6 +59,8 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.TreeSet; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; @@ -127,6 +129,47 @@ void set(Collection newPeers) { } } + static class RepliedCallIds { + private final Object name; + /** The replied callIds. */ + private Set replied = new TreeSet<>(); + /** + * Map: callId to-be-sent -> replied callIds to-be-included. + * When retrying the same callId, the request will include the same set of replied callIds. + * + * @see RaftClientRequest#getRepliedCallIds() + */ + private final ConcurrentMap> sent = new ConcurrentHashMap<>(); + + RepliedCallIds(Object name) { + this.name = name; + } + + /** The given callId is replied. */ + void add(long repliedCallId) { + LOG.debug("{}: add replied callId {}", name, repliedCallId); + synchronized (this) { + // synchronized to avoid adding to a previous set. + replied.add(repliedCallId); + } + sent.remove(repliedCallId); + } + + /** @return the replied callIds for the given callId. */ + Iterable get(long callId) { + final Supplier> supplier = MemoizedSupplier.valueOf(this::getAndReset); + final Set set = Collections.unmodifiableSet(sent.computeIfAbsent(callId, cid -> supplier.get())); + LOG.debug("{}: get {} returns {}", name, callId, set); + return set; + } + + private synchronized Set getAndReset() { + final Set previous = replied; + replied = new TreeSet<>(); + return previous; + } + } + private final ClientId clientId; private final RaftClientRpc clientRpc; private final RaftPeerList peers = new RaftPeerList(); @@ -134,6 +177,8 @@ void set(Collection newPeers) { private final RetryPolicy retryPolicy; private volatile RaftPeerId leaderId; + /** The callIds of the replied requests. */ + private final RepliedCallIds repliedCallIds; private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance(); @@ -158,6 +203,7 @@ void set(Collection newPeers) { this.leaderId = Objects.requireNonNull(computeLeaderId(leaderId, group), () -> "this.leaderId is set to null, leaderId=" + leaderId + ", group=" + group); + this.repliedCallIds = new RepliedCallIds(clientId); this.retryPolicy = Objects.requireNonNull(retryPolicy, "retry policy can't be null"); clientRpc.addRaftPeers(group.getPeers()); @@ -241,7 +287,8 @@ RaftClientRequest newRaftClientRequest( if (server != null) { b.setServerId(server); } else { - b.setLeaderId(getLeaderId()); + b.setLeaderId(getLeaderId()) + .setRepliedCallIds(repliedCallIds.get(callId)); } return b.setClientId(clientId) .setGroupId(groupId) @@ -307,8 +354,14 @@ Throwable noMoreRetries(ClientRetryEvent event) { } RaftClientReply handleReply(RaftClientRequest request, RaftClientReply reply) { - if (request.isToLeader() && reply != null && reply.getException() == null) { - LEADER_CACHE.put(reply.getRaftGroupId(), reply.getServerId()); + if (request.isToLeader() && reply != null) { + if (!request.getType().isReadOnly()) { + repliedCallIds.add(reply.getCallId()); + } + + if (reply.getException() == null) { + LEADER_CACHE.put(reply.getRaftGroupId(), reply.getServerId()); + } } return reply; } diff --git a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java index 7c55a18221..9d853b48b6 100644 --- a/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java +++ b/ratis-common/src/main/java/org/apache/ratis/protocol/RaftClientRequest.java @@ -21,7 +21,9 @@ import org.apache.ratis.util.Preconditions; import org.apache.ratis.util.ProtoUtils; +import java.util.Collections; import java.util.Objects; +import java.util.Optional; import static org.apache.ratis.proto.RaftProtos.RaftClientRequestProto.TypeCase.*; @@ -266,6 +268,7 @@ public static class Builder { private RaftGroupId groupId; private long callId; private boolean toLeader; + private Iterable repliedCallIds = Collections.emptyList(); private Message message; private Type type; @@ -304,6 +307,11 @@ public Builder setCallId(long callId) { return this; } + public Builder setRepliedCallIds(Iterable repliedCallIds) { + this.repliedCallIds = repliedCallIds; + return this; + } + public Builder setMessage(Message message) { this.message = message; return this; @@ -350,6 +358,7 @@ public static RaftClientRequest toWriteRequest(RaftClientRequest r, Message mess private final Message message; private final Type type; + private final Iterable repliedCallIds; private final SlidingWindowEntry slidingWindowEntry; private final RoutingTable routingTable; @@ -386,8 +395,8 @@ private RaftClientRequest(Builder b) { this.message = b.message; this.type = b.type; - this.slidingWindowEntry = b.slidingWindowEntry != null ? b.slidingWindowEntry - : SlidingWindowEntry.getDefaultInstance(); + this.repliedCallIds = Optional.ofNullable(b.repliedCallIds).orElseGet(Collections::emptyList); + this.slidingWindowEntry = b.slidingWindowEntry; this.routingTable = b.routingTable; this.timeoutMs = b.timeoutMs; } @@ -401,6 +410,10 @@ public boolean isToLeader() { return toLeader; } + public Iterable getRepliedCallIds() { + return repliedCallIds; + } + public SlidingWindowEntry getSlidingWindowEntry() { return slidingWindowEntry; } diff --git a/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java b/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java index 85e6ef06be..abc24cc09b 100644 --- a/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java +++ b/ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java @@ -22,11 +22,11 @@ /** * A long ID for RPC calls. - * + *

* This class is threadsafe. */ public final class CallId { - private static final AtomicLong CALL_ID_COUNTER = new AtomicLong(); + private static final AtomicLong CALL_ID_COUNTER = new AtomicLong(1); private static final Comparator COMPARATOR = (left, right) -> { final long diff = left - right; diff --git a/ratis-proto/src/main/proto/Raft.proto b/ratis-proto/src/main/proto/Raft.proto index 49a107c455..d8a1b626a7 100644 --- a/ratis-proto/src/main/proto/Raft.proto +++ b/ratis-proto/src/main/proto/Raft.proto @@ -117,6 +117,7 @@ message RaftRpcRequestProto { uint64 callId = 4; bool toLeader = 5; + repeated uint64 repliedCallIds = 12; // The call ids of the replied requests uint64 timeoutMs = 13; RoutingTableProto routingTable = 14; SlidingWindowEntry slidingWindowEntry = 15; diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java index 667e611b42..8005be8947 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RaftServerImpl.java @@ -900,6 +900,8 @@ public CompletableFuture submitClientRequestAsync( } private CompletableFuture replyFuture(RaftClientRequest request) throws IOException { + retryCache.invalidateRepliedRequests(request); + final TypeCase type = request.getType().getTypeCase(); switch (type) { case STALEREAD: @@ -925,7 +927,7 @@ private CompletableFuture writeAsync(RaftClientRequest request) } // query the retry cache - final RetryCacheImpl.CacheQueryResult queryResult = retryCache.queryCache(ClientInvocationId.valueOf(request)); + final RetryCacheImpl.CacheQueryResult queryResult = retryCache.queryCache(request); final CacheEntry cacheEntry = queryResult.getEntry(); if (queryResult.isRetry()) { // return the cached future. @@ -1784,7 +1786,7 @@ private CompletableFuture replyPendingRequest( ClientInvocationId invocationId, long logIndex, CompletableFuture stateMachineFuture) { // update the retry cache final CacheEntry cacheEntry = retryCache.getOrCreateEntry(invocationId); - Preconditions.assertTrue(cacheEntry != null); + Objects.requireNonNull(cacheEntry , "cacheEntry == null"); if (getInfo().isLeader() && cacheEntry.isCompletedNormally()) { LOG.warn("{} retry cache entry of leader should be pending: {}", this, cacheEntry); } diff --git a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java index 438315ed7a..a8bac4e5e8 100644 --- a/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java +++ b/ratis-server/src/main/java/org/apache/ratis/server/impl/RetryCacheImpl.java @@ -18,14 +18,18 @@ package org.apache.ratis.server.impl; import org.apache.ratis.conf.RaftProperties; +import org.apache.ratis.protocol.ClientId; import org.apache.ratis.protocol.ClientInvocationId; import org.apache.ratis.protocol.RaftClientReply; +import org.apache.ratis.protocol.RaftClientRequest; import org.apache.ratis.server.RaftServerConfigKeys; import org.apache.ratis.server.RetryCache; import org.apache.ratis.thirdparty.com.google.common.cache.Cache; import org.apache.ratis.thirdparty.com.google.common.cache.CacheBuilder; import org.apache.ratis.thirdparty.com.google.common.cache.CacheStats; +import org.apache.ratis.util.CollectionUtils; import org.apache.ratis.util.JavaUtils; +import org.apache.ratis.util.MemoizedSupplier; import org.apache.ratis.util.TimeDuration; import org.apache.ratis.util.Timestamp; @@ -33,6 +37,7 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; class RetryCacheImpl implements RetryCache { static class CacheEntry implements Entry { @@ -181,13 +186,15 @@ public String toString() { } CacheEntry getOrCreateEntry(ClientInvocationId key) { - final CacheEntry entry; + return getOrCreateEntry(key, () -> new CacheEntry(key)); + } + + private CacheEntry getOrCreateEntry(ClientInvocationId key, Supplier constructor) { try { - entry = cache.get(key, () -> new CacheEntry(key)); + return cache.get(key, constructor::get); } catch (ExecutionException e) { - throw new IllegalStateException(e); + throw new IllegalStateException("Failed to get " + key, e); } - return entry; } CacheEntry refreshEntry(CacheEntry newEntry) { @@ -195,16 +202,11 @@ CacheEntry refreshEntry(CacheEntry newEntry) { return newEntry; } - CacheQueryResult queryCache(ClientInvocationId key) { - final CacheEntry newEntry = new CacheEntry(key); - final CacheEntry cacheEntry; - try { - cacheEntry = cache.get(key, () -> newEntry); - } catch (ExecutionException e) { - throw new IllegalStateException(e); - } - - if (cacheEntry == newEntry) { + CacheQueryResult queryCache(RaftClientRequest request) { + final ClientInvocationId key = ClientInvocationId.valueOf(request); + final MemoizedSupplier newEntry = MemoizedSupplier.valueOf(() -> new CacheEntry(key)); + final CacheEntry cacheEntry = getOrCreateEntry(key, newEntry); + if (newEntry.isInitialized()) { // this is the entry we just newly created return new CacheQueryResult(cacheEntry, false); } else if (!cacheEntry.isDone() || !cacheEntry.isFailed()){ @@ -221,13 +223,24 @@ CacheQueryResult queryCache(ClientInvocationId key) { if (currentEntry == cacheEntry || currentEntry == null) { // if the failed entry has not got replaced by another retry, or the // failed entry got invalidated, we add a new cache entry - return new CacheQueryResult(refreshEntry(newEntry), false); + return new CacheQueryResult(refreshEntry(newEntry.get()), false); } else { return new CacheQueryResult(currentEntry, true); } } } + void invalidateRepliedRequests(RaftClientRequest request) { + final ClientId clientId = request.getClientId(); + final Iterable callIds = request.getRepliedCallIds(); + if (!callIds.iterator().hasNext()) { + return; + } + + LOG.debug("invalidateRepliedRequests callIds {} for {}", callIds, clientId); + cache.invalidateAll(CollectionUtils.as(callIds, callId -> ClientInvocationId.valueOf(clientId, callId))); + } + @Override public Statistics getStatistics() { return statistics.updateAndGet(old -> old == null || old.isExpired()? new StatisticsImpl(cache): old); @@ -240,10 +253,8 @@ public CacheEntry getIfPresent(ClientInvocationId key) { @Override public synchronized void close() { - if (cache != null) { - cache.invalidateAll(); - statistics.set(null); - } + cache.invalidateAll(); + statistics.set(null); } static CompletableFuture failWithReply( diff --git a/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java b/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java index f729dcd2d5..288aa71a91 100644 --- a/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java +++ b/ratis-server/src/test/java/org/apache/ratis/RetryCacheTests.java @@ -82,11 +82,10 @@ void runTestBasicRetry(CLUSTER cluster) throws Exception { } } - public static RaftClient assertReply(RaftClientReply reply, RaftClient client, long callId) { + public static void assertReply(RaftClientReply reply, RaftClient client, long callId) { Assert.assertEquals(client.getId(), reply.getClientId()); Assert.assertEquals(callId, reply.getCallId()); Assert.assertTrue(reply.isSuccess()); - return client; } public void assertServer(MiniRaftCluster cluster, ClientId clientId, long callId, long oldLastApplied) throws Exception { diff --git a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java b/ratis-server/src/test/java/org/apache/ratis/client/impl/RaftClientTestUtil.java similarity index 90% rename from ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java rename to ratis-server/src/test/java/org/apache/ratis/client/impl/RaftClientTestUtil.java index ba00b8f00d..d90b0cc53f 100644 --- a/ratis-client/src/main/java/org/apache/ratis/client/impl/RaftClientTestUtil.java +++ b/ratis-server/src/test/java/org/apache/ratis/client/impl/RaftClientTestUtil.java @@ -21,6 +21,7 @@ import org.apache.ratis.proto.RaftProtos.SlidingWindowEntry; import org.apache.ratis.protocol.ClientInvocationId; import org.apache.ratis.protocol.Message; +import org.apache.ratis.protocol.RaftClientReply; import org.apache.ratis.protocol.RaftClientRequest; import org.apache.ratis.protocol.RaftPeerId; import org.apache.ratis.rpc.CallId; @@ -39,4 +40,8 @@ static RaftClientRequest newRaftClientRequest(RaftClient client, RaftPeerId serv long callId, Message message, RaftClientRequest.Type type, SlidingWindowEntry slidingWindowEntry) { return ((RaftClientImpl)client).newRaftClientRequest(server, callId, message, type, slidingWindowEntry); } + + static void handleReply(RaftClientRequest request, RaftClientReply reply, RaftClient client) { + ((RaftClientImpl)client).handleReply(request, reply); + } } diff --git a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java index 2ce5643ccf..312c9508d3 100644 --- a/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java +++ b/ratis-server/src/test/java/org/apache/ratis/statemachine/impl/SimpleStateMachine4Testing.java @@ -126,7 +126,7 @@ CompletableFuture collect(Type type, T value) { static class Blocking { enum Type { - START_TRANSACTION, READ_STATE_MACHINE_DATA, WRITE_STATE_MACHINE_DATA, FLUSH_STATE_MACHINE_DATA + START_TRANSACTION, APPLY_TRANSACTION, READ_STATE_MACHINE_DATA, WRITE_STATE_MACHINE_DATA, FLUSH_STATE_MACHINE_DATA } private final EnumMap> maps = new EnumMap<>(Type.class); @@ -243,7 +243,10 @@ public synchronized void reinitialize() throws IOException { @Override public CompletableFuture applyTransaction(TransactionContext trx) { + blocking.await(Blocking.Type.APPLY_TRANSACTION); LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry()); + LOG.info("applyTransaction for log index {}", entry.getIndex()); + put(entry); updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex()); @@ -386,6 +389,13 @@ public void unblockStartTransaction() { blocking.unblock(Blocking.Type.START_TRANSACTION); } + public void blockApplyTransaction() { + blocking.block(Blocking.Type.APPLY_TRANSACTION); + } + public void unblockApplyTransaction() { + blocking.unblock(Blocking.Type.APPLY_TRANSACTION); + } + public void blockWriteStateMachineData() { blocking.block(Blocking.Type.WRITE_STATE_MACHINE_DATA); } diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java index 7de1c40428..0af1d87cce 100644 --- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java +++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRaftServerWithGrpc.java @@ -23,6 +23,8 @@ import static org.apache.ratis.server.metrics.RaftServerMetricsImpl.RAFT_CLIENT_WRITE_REQUEST; import org.apache.ratis.metrics.RatisMetricRegistry; +import org.apache.ratis.protocol.ClientInvocationId; +import org.apache.ratis.server.RetryCache; import org.apache.ratis.util.JavaUtils; import org.slf4j.event.Level; import org.apache.ratis.conf.Parameters; @@ -183,19 +185,31 @@ void runTestLeaderRestart(MiniRaftClusterWithGrpc cluster) throws Exception { final RaftClientRpc rpc = client.getClientRpc(); final AtomicLong seqNum = new AtomicLong(); + final ClientInvocationId invocationId; { // send a request using rpc directly - final RaftClientRequest request = newRaftClientRequest(client, leader.getId(), seqNum.incrementAndGet()); + final RaftClientRequest request = newRaftClientRequest(client, seqNum.incrementAndGet()); + Assert.assertEquals(client.getId(), request.getClientId()); final CompletableFuture f = rpc.sendRequestAsync(request); - Assert.assertTrue(f.get().isSuccess()); + final RaftClientReply reply = f.get(); + Assert.assertTrue(reply.isSuccess()); + RaftClientTestUtil.handleReply(request, reply, client); + invocationId = ClientInvocationId.valueOf(request.getClientId(), request.getCallId()); + final RetryCache.Entry entry = leader.getRetryCache().getIfPresent(invocationId); + Assert.assertNotNull(entry); + LOG.info("cache entry {}", entry); } // send another request which will be blocked final SimpleStateMachine4Testing stateMachine = SimpleStateMachine4Testing.get(leader); stateMachine.blockStartTransaction(); - final RaftClientRequest requestBlocked = newRaftClientRequest(client, leader.getId(), seqNum.incrementAndGet()); + final RaftClientRequest requestBlocked = newRaftClientRequest(client, seqNum.incrementAndGet()); final CompletableFuture futureBlocked = rpc.sendRequestAsync(requestBlocked); + JavaUtils.attempt(() -> Assert.assertNull(leader.getRetryCache().getIfPresent(invocationId)), + 10, HUNDRED_MILLIS, "invalidate cache entry", LOG); + LOG.info("cache entry not found for {}", invocationId); + // change leader RaftTestUtil.changeLeader(cluster, leader.getId()); Assert.assertNotEquals(RaftPeerRole.LEADER, leader.getInfo().getCurrentRole()); @@ -206,7 +220,7 @@ void runTestLeaderRestart(MiniRaftClusterWithGrpc cluster) throws Exception { stateMachine.unblockStartTransaction(); // send one more request which should timeout. - final RaftClientRequest requestTimeout = newRaftClientRequest(client, leader.getId(), seqNum.incrementAndGet()); + final RaftClientRequest requestTimeout = newRaftClientRequest(client, seqNum.incrementAndGet()); rpc.handleException(leader.getId(), new Exception(), true); final CompletableFuture f = rpc.sendRequestAsync(requestTimeout); testFailureCase("request should timeout", f::get, @@ -346,9 +360,9 @@ void testRaftClientRequestMetrics(MiniRaftClusterWithGrpc cluster) throws IOExce } } - static RaftClientRequest newRaftClientRequest(RaftClient client, RaftPeerId serverId, long seqNum) { + static RaftClientRequest newRaftClientRequest(RaftClient client, long seqNum) { final SimpleMessage m = new SimpleMessage("m" + seqNum); - return RaftClientTestUtil.newRaftClientRequest(client, serverId, seqNum, m, + return RaftClientTestUtil.newRaftClientRequest(client, null, seqNum, m, RaftClientRequest.writeRequestType(), ProtoUtils.toSlidingWindowEntry(seqNum, seqNum == 1L)); } diff --git a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java index 400e6e5a64..8a1878cd8d 100644 --- a/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java +++ b/ratis-test/src/test/java/org/apache/ratis/grpc/TestRetryCacheWithGrpc.java @@ -17,8 +17,12 @@ */ package org.apache.ratis.grpc; +import org.apache.ratis.client.RaftClient; +import org.apache.ratis.proto.RaftProtos; +import org.apache.ratis.server.RetryCache; import org.apache.ratis.server.impl.MiniRaftCluster; import org.apache.ratis.RaftTestUtil; +import org.apache.ratis.RaftTestUtil.SimpleMessage; import org.apache.ratis.RetryCacheTests; import org.apache.ratis.conf.RaftProperties; import org.apache.ratis.protocol.ClientId; @@ -30,15 +34,131 @@ import org.apache.ratis.server.impl.RetryCacheTestUtil; import org.apache.ratis.statemachine.impl.SimpleStateMachine4Testing; import org.apache.ratis.statemachine.StateMachine; +import org.apache.ratis.util.Slf4jUtils; +import org.junit.Assert; import org.junit.Test; +import org.slf4j.event.Level; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; public class TestRetryCacheWithGrpc extends RetryCacheTests implements MiniRaftClusterWithGrpc.FactoryGet { + { + Slf4jUtils.setLogLevel(RetryCache.LOG, Level.TRACE); + + getProperties().setClass(MiniRaftCluster.STATEMACHINE_CLASS_KEY, + SimpleStateMachine4Testing.class, StateMachine.class); + } + + @Test + public void testInvalidateRepliedCalls() throws Exception { + runWithNewCluster(3, cluster -> new InvalidateRepliedCallsTest(cluster).run()); + } + + static long assertReply(RaftClientReply reply) { + Assert.assertTrue(reply.isSuccess()); + return reply.getCallId(); + } + + class InvalidateRepliedCallsTest { + private final MiniRaftCluster cluster; + private final RaftServer.Division leader; + private final AtomicInteger count = new AtomicInteger(); + + InvalidateRepliedCallsTest(MiniRaftCluster cluster) throws Exception { + this.cluster = cluster; + this.leader = RaftTestUtil.waitForLeader(cluster); + } + + SimpleMessage nextMessage() { + return new SimpleMessage("m" + count.incrementAndGet()); + } + + void assertRetryCacheEntry(RaftClient client, long callId, boolean exist) { + final RetryCache.Entry e = RetryCacheTestUtil.get(leader, client.getId(), callId); + if (exist) { + Assert.assertNotNull(e); + } else { + Assert.assertNull(e); + } + } + + long send(RaftClient client, Long previousCallId) throws Exception { + final RaftClientReply reply = client.io().send(nextMessage()); + final long callId = assertReply(reply); + if (previousCallId != null) { + // the previous should be invalidated. + assertRetryCacheEntry(client, previousCallId, false); + } + // the current should exist. + assertRetryCacheEntry(client, callId, true); + return callId; + } + + CompletableFuture sendAsync(RaftClient client) { + return client.async().send(nextMessage()) + .thenApply(TestRetryCacheWithGrpc::assertReply); + } + + CompletableFuture watch(long logIndex, RaftClient client) { + return client.async().watch(logIndex, RaftProtos.ReplicationLevel.MAJORITY) + .thenApply(TestRetryCacheWithGrpc::assertReply); + } + + void run() throws Exception { + try (RaftClient client = cluster.createClient()) { + // test blocking io + Long lastBlockingCall = null; + for (int i = 0; i < 5; i++) { + lastBlockingCall = send(client, lastBlockingCall); + } + final long lastBlockingCallId = lastBlockingCall; + + // test async + final SimpleStateMachine4Testing stateMachine = SimpleStateMachine4Testing.get(leader); + stateMachine.blockApplyTransaction(); + final List> asyncCalls = new ArrayList<>(); + for (int i = 0; i < 5; i++) { + // Since applyTransaction is blocked, the replied call id remains the same. + asyncCalls.add(sendAsync(client)); + } + // async call will invalidate blocking calls even if applyTransaction is blocked. + assertRetryCacheEntry(client, lastBlockingCallId, false); + + ONE_SECOND.sleep(); + // No calls can be completed. + for (CompletableFuture f : asyncCalls) { + Assert.assertFalse(f.isDone()); + } + stateMachine.unblockApplyTransaction(); + // No calls can be invalidated. + for (CompletableFuture f : asyncCalls) { + assertRetryCacheEntry(client, f.join(), true); + } + + // one more blocking call will invalidate all async calls + final long oneMoreBlockingCall = send(client, null); + LOG.info("oneMoreBlockingCall callId={}", oneMoreBlockingCall); + assertRetryCacheEntry(client, oneMoreBlockingCall, true); + for (CompletableFuture f : asyncCalls) { + assertRetryCacheEntry(client, f.join(), false); + } + + // watch call will invalidate blocking calls + final long watchAsyncCall = watch(1, client).get(); + LOG.info("watchAsyncCall callId={}", watchAsyncCall); + assertRetryCacheEntry(client, oneMoreBlockingCall, false); + // retry cache should not contain watch calls + assertRetryCacheEntry(client, watchAsyncCall, false); + } + } + } @Test(timeout = 10000) public void testRetryOnResourceUnavailableException()