Skip to content

Commit

Permalink
RATIS-872. Invalidate replied calls in retry cache.
Browse files Browse the repository at this point in the history
  • Loading branch information
szetszwo committed Oct 25, 2023
1 parent 3e7f9e5 commit 8931a0a
Show file tree
Hide file tree
Showing 12 changed files with 267 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ static RaftRpcRequestProto.Builder toRaftRpcRequestProtoBuilder(RaftClientReques

return b.setCallId(request.getCallId())
.setToLeader(request.isToLeader())
.addAllRepliedCallIds(request.getRepliedCallIds())
.setTimeoutMs(request.getTimeoutMs());
}

Expand Down Expand Up @@ -192,6 +193,7 @@ static RaftClientRequest toRaftClientRequest(RaftClientRequestProto p) {
.setCallId(request.getCallId())
.setMessage(toMessage(p.getMessage()))
.setType(type)
.setRepliedCallIds(request.getRepliedCallIdsList())
.setRoutingTable(getRoutingTable(request))
.setTimeoutMs(request.getTimeoutMs())
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
Expand Down Expand Up @@ -127,13 +129,56 @@ void set(Collection<RaftPeer> newPeers) {
}
}

static class RepliedCallIds {
private final Object name;
/** The replied callIds. */
private Set<Long> replied = new TreeSet<>();
/**
* Map: callId to-be-sent -> replied callIds to-be-included.
* When retrying the same callId, the request will include the same set of replied callIds.
*
* @see RaftClientRequest#getRepliedCallIds()
*/
private final ConcurrentMap<Long, Set<Long>> sent = new ConcurrentHashMap<>();

RepliedCallIds(Object name) {
this.name = name;
}

/** The given callId is replied. */
void add(long repliedCallId) {
LOG.debug("{}: add replied callId {}", name, repliedCallId);
synchronized (this) {
// synchronized to avoid adding to a previous set.
replied.add(repliedCallId);
}
sent.remove(repliedCallId);
}

/** @return the replied callIds for the given callId. */
Iterable<Long> get(long callId) {
final Supplier<Set<Long>> supplier = MemoizedSupplier.valueOf(this::getAndReset);
final Set<Long> set = Collections.unmodifiableSet(sent.computeIfAbsent(callId, cid -> supplier.get()));
LOG.debug("{}: get {} returns {}", name, callId, set);
return set;
}

private synchronized Set<Long> getAndReset() {
final Set<Long> previous = replied;
replied = new TreeSet<>();
return previous;
}
}

private final ClientId clientId;
private final RaftClientRpc clientRpc;
private final RaftPeerList peers = new RaftPeerList();
private final RaftGroupId groupId;
private final RetryPolicy retryPolicy;

private volatile RaftPeerId leaderId;
/** The callIds of the replied requests. */
private final RepliedCallIds repliedCallIds;

private final TimeoutExecutor scheduler = TimeoutExecutor.getInstance();

Expand All @@ -158,6 +203,7 @@ void set(Collection<RaftPeer> newPeers) {

this.leaderId = Objects.requireNonNull(computeLeaderId(leaderId, group),
() -> "this.leaderId is set to null, leaderId=" + leaderId + ", group=" + group);
this.repliedCallIds = new RepliedCallIds(clientId);
this.retryPolicy = Objects.requireNonNull(retryPolicy, "retry policy can't be null");

clientRpc.addRaftPeers(group.getPeers());
Expand Down Expand Up @@ -241,7 +287,8 @@ RaftClientRequest newRaftClientRequest(
if (server != null) {
b.setServerId(server);
} else {
b.setLeaderId(getLeaderId());
b.setLeaderId(getLeaderId())
.setRepliedCallIds(repliedCallIds.get(callId));
}
return b.setClientId(clientId)
.setGroupId(groupId)
Expand Down Expand Up @@ -307,8 +354,14 @@ Throwable noMoreRetries(ClientRetryEvent event) {
}

RaftClientReply handleReply(RaftClientRequest request, RaftClientReply reply) {
if (request.isToLeader() && reply != null && reply.getException() == null) {
LEADER_CACHE.put(reply.getRaftGroupId(), reply.getServerId());
if (request.isToLeader() && reply != null) {
if (!request.getType().isReadOnly()) {
repliedCallIds.add(reply.getCallId());
}

if (reply.getException() == null) {
LEADER_CACHE.put(reply.getRaftGroupId(), reply.getServerId());
}
}
return reply;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@
import org.apache.ratis.util.Preconditions;
import org.apache.ratis.util.ProtoUtils;

import java.util.Collections;
import java.util.Objects;
import java.util.Optional;

import static org.apache.ratis.proto.RaftProtos.RaftClientRequestProto.TypeCase.*;

Expand Down Expand Up @@ -266,6 +268,7 @@ public static class Builder {
private RaftGroupId groupId;
private long callId;
private boolean toLeader;
private Iterable<Long> repliedCallIds = Collections.emptyList();

private Message message;
private Type type;
Expand Down Expand Up @@ -304,6 +307,11 @@ public Builder setCallId(long callId) {
return this;
}

public Builder setRepliedCallIds(Iterable<Long> repliedCallIds) {
this.repliedCallIds = repliedCallIds;
return this;
}

public Builder setMessage(Message message) {
this.message = message;
return this;
Expand Down Expand Up @@ -350,6 +358,7 @@ public static RaftClientRequest toWriteRequest(RaftClientRequest r, Message mess
private final Message message;
private final Type type;

private final Iterable<Long> repliedCallIds;
private final SlidingWindowEntry slidingWindowEntry;

private final RoutingTable routingTable;
Expand Down Expand Up @@ -386,8 +395,8 @@ private RaftClientRequest(Builder b) {

this.message = b.message;
this.type = b.type;
this.slidingWindowEntry = b.slidingWindowEntry != null ? b.slidingWindowEntry
: SlidingWindowEntry.getDefaultInstance();
this.repliedCallIds = Optional.ofNullable(b.repliedCallIds).orElseGet(Collections::emptyList);
this.slidingWindowEntry = b.slidingWindowEntry;
this.routingTable = b.routingTable;
this.timeoutMs = b.timeoutMs;
}
Expand All @@ -401,6 +410,10 @@ public boolean isToLeader() {
return toLeader;
}

public Iterable<Long> getRepliedCallIds() {
return repliedCallIds;
}

public SlidingWindowEntry getSlidingWindowEntry() {
return slidingWindowEntry;
}
Expand Down
4 changes: 2 additions & 2 deletions ratis-common/src/main/java/org/apache/ratis/rpc/CallId.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@

/**
* A long ID for RPC calls.
*
* <p>
* This class is threadsafe.
*/
public final class CallId {
private static final AtomicLong CALL_ID_COUNTER = new AtomicLong();
private static final AtomicLong CALL_ID_COUNTER = new AtomicLong(1);

private static final Comparator<Long> COMPARATOR = (left, right) -> {
final long diff = left - right;
Expand Down
1 change: 1 addition & 0 deletions ratis-proto/src/main/proto/Raft.proto
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ message RaftRpcRequestProto {
uint64 callId = 4;
bool toLeader = 5;

repeated uint64 repliedCallIds = 12; // The call ids of the replied requests
uint64 timeoutMs = 13;
RoutingTableProto routingTable = 14;
SlidingWindowEntry slidingWindowEntry = 15;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,8 @@ public CompletableFuture<RaftClientReply> submitClientRequestAsync(
}

private CompletableFuture<RaftClientReply> replyFuture(RaftClientRequest request) throws IOException {
retryCache.invalidateRepliedRequests(request);

final TypeCase type = request.getType().getTypeCase();
switch (type) {
case STALEREAD:
Expand All @@ -925,7 +927,7 @@ private CompletableFuture<RaftClientReply> writeAsync(RaftClientRequest request)
}

// query the retry cache
final RetryCacheImpl.CacheQueryResult queryResult = retryCache.queryCache(ClientInvocationId.valueOf(request));
final RetryCacheImpl.CacheQueryResult queryResult = retryCache.queryCache(request);
final CacheEntry cacheEntry = queryResult.getEntry();
if (queryResult.isRetry()) {
// return the cached future.
Expand Down Expand Up @@ -1784,7 +1786,7 @@ private CompletableFuture<Message> replyPendingRequest(
ClientInvocationId invocationId, long logIndex, CompletableFuture<Message> stateMachineFuture) {
// update the retry cache
final CacheEntry cacheEntry = retryCache.getOrCreateEntry(invocationId);
Preconditions.assertTrue(cacheEntry != null);
Objects.requireNonNull(cacheEntry , "cacheEntry == null");
if (getInfo().isLeader() && cacheEntry.isCompletedNormally()) {
LOG.warn("{} retry cache entry of leader should be pending: {}", this, cacheEntry);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,26 @@
package org.apache.ratis.server.impl;

import org.apache.ratis.conf.RaftProperties;
import org.apache.ratis.protocol.ClientId;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.server.RaftServerConfigKeys;
import org.apache.ratis.server.RetryCache;
import org.apache.ratis.thirdparty.com.google.common.cache.Cache;
import org.apache.ratis.thirdparty.com.google.common.cache.CacheBuilder;
import org.apache.ratis.thirdparty.com.google.common.cache.CacheStats;
import org.apache.ratis.util.CollectionUtils;
import org.apache.ratis.util.JavaUtils;
import org.apache.ratis.util.MemoizedSupplier;
import org.apache.ratis.util.TimeDuration;
import org.apache.ratis.util.Timestamp;

import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

class RetryCacheImpl implements RetryCache {
static class CacheEntry implements Entry {
Expand Down Expand Up @@ -181,30 +186,27 @@ public String toString() {
}

CacheEntry getOrCreateEntry(ClientInvocationId key) {
final CacheEntry entry;
return getOrCreateEntry(key, () -> new CacheEntry(key));
}

private CacheEntry getOrCreateEntry(ClientInvocationId key, Supplier<CacheEntry> constructor) {
try {
entry = cache.get(key, () -> new CacheEntry(key));
return cache.get(key, constructor::get);
} catch (ExecutionException e) {
throw new IllegalStateException(e);
throw new IllegalStateException("Failed to get " + key, e);
}
return entry;
}

CacheEntry refreshEntry(CacheEntry newEntry) {
cache.put(newEntry.getKey(), newEntry);
return newEntry;
}

CacheQueryResult queryCache(ClientInvocationId key) {
final CacheEntry newEntry = new CacheEntry(key);
final CacheEntry cacheEntry;
try {
cacheEntry = cache.get(key, () -> newEntry);
} catch (ExecutionException e) {
throw new IllegalStateException(e);
}

if (cacheEntry == newEntry) {
CacheQueryResult queryCache(RaftClientRequest request) {
final ClientInvocationId key = ClientInvocationId.valueOf(request);
final MemoizedSupplier<CacheEntry> newEntry = MemoizedSupplier.valueOf(() -> new CacheEntry(key));
final CacheEntry cacheEntry = getOrCreateEntry(key, newEntry);
if (newEntry.isInitialized()) {
// this is the entry we just newly created
return new CacheQueryResult(cacheEntry, false);
} else if (!cacheEntry.isDone() || !cacheEntry.isFailed()){
Expand All @@ -221,13 +223,24 @@ CacheQueryResult queryCache(ClientInvocationId key) {
if (currentEntry == cacheEntry || currentEntry == null) {
// if the failed entry has not got replaced by another retry, or the
// failed entry got invalidated, we add a new cache entry
return new CacheQueryResult(refreshEntry(newEntry), false);
return new CacheQueryResult(refreshEntry(newEntry.get()), false);
} else {
return new CacheQueryResult(currentEntry, true);
}
}
}

void invalidateRepliedRequests(RaftClientRequest request) {
final ClientId clientId = request.getClientId();
final Iterable<Long> callIds = request.getRepliedCallIds();
if (!callIds.iterator().hasNext()) {
return;
}

LOG.debug("invalidateRepliedRequests callIds {} for {}", callIds, clientId);
cache.invalidateAll(CollectionUtils.as(callIds, callId -> ClientInvocationId.valueOf(clientId, callId)));
}

@Override
public Statistics getStatistics() {
return statistics.updateAndGet(old -> old == null || old.isExpired()? new StatisticsImpl(cache): old);
Expand All @@ -240,10 +253,8 @@ public CacheEntry getIfPresent(ClientInvocationId key) {

@Override
public synchronized void close() {
if (cache != null) {
cache.invalidateAll();
statistics.set(null);
}
cache.invalidateAll();
statistics.set(null);
}

static CompletableFuture<RaftClientReply> failWithReply(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,10 @@ void runTestBasicRetry(CLUSTER cluster) throws Exception {
}
}

public static RaftClient assertReply(RaftClientReply reply, RaftClient client, long callId) {
public static void assertReply(RaftClientReply reply, RaftClient client, long callId) {
Assert.assertEquals(client.getId(), reply.getClientId());
Assert.assertEquals(callId, reply.getCallId());
Assert.assertTrue(reply.isSuccess());
return client;
}

public void assertServer(MiniRaftCluster cluster, ClientId clientId, long callId, long oldLastApplied) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.ratis.proto.RaftProtos.SlidingWindowEntry;
import org.apache.ratis.protocol.ClientInvocationId;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftClientReply;
import org.apache.ratis.protocol.RaftClientRequest;
import org.apache.ratis.protocol.RaftPeerId;
import org.apache.ratis.rpc.CallId;
Expand All @@ -39,4 +40,8 @@ static RaftClientRequest newRaftClientRequest(RaftClient client, RaftPeerId serv
long callId, Message message, RaftClientRequest.Type type, SlidingWindowEntry slidingWindowEntry) {
return ((RaftClientImpl)client).newRaftClientRequest(server, callId, message, type, slidingWindowEntry);
}

static void handleReply(RaftClientRequest request, RaftClientReply reply, RaftClient client) {
((RaftClientImpl)client).handleReply(request, reply);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ <T> CompletableFuture<T> collect(Type type, T value) {

static class Blocking {
enum Type {
START_TRANSACTION, READ_STATE_MACHINE_DATA, WRITE_STATE_MACHINE_DATA, FLUSH_STATE_MACHINE_DATA
START_TRANSACTION, APPLY_TRANSACTION, READ_STATE_MACHINE_DATA, WRITE_STATE_MACHINE_DATA, FLUSH_STATE_MACHINE_DATA
}

private final EnumMap<Type, CompletableFuture<Void>> maps = new EnumMap<>(Type.class);
Expand Down Expand Up @@ -243,7 +243,10 @@ public synchronized void reinitialize() throws IOException {

@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
blocking.await(Blocking.Type.APPLY_TRANSACTION);
LogEntryProto entry = Objects.requireNonNull(trx.getLogEntry());
LOG.info("applyTransaction for log index {}", entry.getIndex());

put(entry);
updateLastAppliedTermIndex(entry.getTerm(), entry.getIndex());

Expand Down Expand Up @@ -386,6 +389,13 @@ public void unblockStartTransaction() {
blocking.unblock(Blocking.Type.START_TRANSACTION);
}

public void blockApplyTransaction() {
blocking.block(Blocking.Type.APPLY_TRANSACTION);
}
public void unblockApplyTransaction() {
blocking.unblock(Blocking.Type.APPLY_TRANSACTION);
}

public void blockWriteStateMachineData() {
blocking.block(Blocking.Type.WRITE_STATE_MACHINE_DATA);
}
Expand Down
Loading

0 comments on commit 8931a0a

Please sign in to comment.