Skip to content

Commit

Permalink
KeyedContext#clearAll (#217)
Browse files Browse the repository at this point in the history
* Add `ctx.clearAll()`
* Use `clearAll()` in WorkflowManager
  • Loading branch information
slinkydeveloper authored Feb 8, 2024
1 parent 2635559 commit dd79547
Show file tree
Hide file tree
Showing 16 changed files with 148 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,12 @@ internal class ContextImpl internal constructor(private val syscalls: Syscalls)
}
}

override suspend fun clearAll() {
return suspendCancellableCoroutine { cont: CancellableContinuation<Unit> ->
syscalls.clearAll(completingUnitContinuation(cont))
}
}

override suspend fun timer(duration: Duration): Awaitable<Unit> {
val deferred: Deferred<Void> =
suspendCancellableCoroutine { cont: CancellableContinuation<Deferred<Void>> ->
Expand Down
3 changes: 3 additions & 0 deletions sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,9 @@ sealed interface KeyedContext : UnkeyedContext {
*/
suspend fun clear(key: StateKey<*>)

/** Clears all the state of this service instance key-value state storage */
suspend fun clearAll()

companion object {

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,7 @@ package dev.restate.sdk.kotlin
import dev.restate.sdk.common.CoreSerdes
import dev.restate.sdk.common.StateKey
import dev.restate.sdk.core.EagerStateTestSuite
import dev.restate.sdk.core.testservices.GreeterGrpcKt
import dev.restate.sdk.core.testservices.GreetingRequest
import dev.restate.sdk.core.testservices.GreetingResponse
import dev.restate.sdk.core.testservices.greetingResponse
import dev.restate.sdk.core.testservices.*
import io.grpc.BindableService
import kotlinx.coroutines.Dispatchers
import org.assertj.core.api.AssertionsForClassTypes.assertThat
Expand Down Expand Up @@ -75,4 +72,22 @@ class EagerStateTest : EagerStateTestSuite() {
override fun getClearAndGet(): BindableService {
return GetClearAndGet()
}

private class GetClearAllAndGet : GreeterRestateKt.GreeterRestateKtImplBase() {
override suspend fun greet(context: KeyedContext, request: GreetingRequest): GreetingResponse {
val ctx = KeyedContext.current()
val oldState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))!!

ctx.clearAll()

assertThat(ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))).isNull()
assertThat(ctx.get(StateKey.of("ANOTHER_STATE", CoreSerdes.JSON_STRING))).isNull()

return greetingResponse { message = oldState }
}
}

override fun getClearAllAndGet(): BindableService {
return GetClearAllAndGet()
}
}
5 changes: 5 additions & 0 deletions sdk-api/src/main/java/dev/restate/sdk/ContextImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ public void clear(StateKey<?> key) {
Util.<Void>blockOnSyscall(cb -> syscalls.clear(key.name(), cb));
}

@Override
public void clearAll() {
Util.<Void>blockOnSyscall(syscalls::clearAll);
}

@Override
public <T> void set(StateKey<T> key, @Nonnull T value) {
Util.<Void>blockOnSyscall(
Expand Down
3 changes: 3 additions & 0 deletions sdk-api/src/main/java/dev/restate/sdk/KeyedContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ public interface KeyedContext extends UnkeyedContext {
*/
void clear(StateKey<?> key);

/** Clears all the state of this service instance key-value state storage */
void clearAll();

/**
* Sets the given value under the given key, serializing the value using the {@link Serde} in the
* {@link StateKey}.
Expand Down
22 changes: 22 additions & 0 deletions sdk-api/src/test/java/dev/restate/sdk/EagerStateTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -97,4 +97,26 @@ public void greet(GreetingRequest request, StreamObserver<GreetingResponse> resp
protected BindableService getClearAndGet() {
return new GetClearAndGet();
}

private static class GetClearAllAndGet extends GreeterGrpc.GreeterImplBase
implements RestateService {
@Override
public void greet(GreetingRequest request, StreamObserver<GreetingResponse> responseObserver) {
KeyedContext ctx = KeyedContext.current();

String oldState = ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING)).get();

ctx.clearAll();
assertThat(ctx.get(StateKey.of("STATE", CoreSerdes.JSON_STRING))).isEmpty();
assertThat(ctx.get(StateKey.of("ANOTHER_STATE", CoreSerdes.JSON_STRING))).isEmpty();

responseObserver.onNext(GreetingResponse.newBuilder().setMessage(oldState).build());
responseObserver.onCompleted();
}
}

@Override
protected BindableService getClearAllAndGet() {
return new GetClearAllAndGet();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ static Syscalls current() {

void clear(String name, SyscallCallback<Void> callback);

void clearAll(SyscallCallback<Void> callback);

void set(String name, ByteString value, SyscallCallback<Void> callback);

// ----- Syscalls
Expand Down
24 changes: 24 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/Entries.java
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,30 @@ void updateUserStateStoreWithEntry(
}
}

static final class ClearAllStateEntry extends JournalEntry<ClearAllStateEntryMessage> {

static final ClearAllStateEntry INSTANCE = new ClearAllStateEntry();

private ClearAllStateEntry() {}

@Override
public void trace(ClearAllStateEntryMessage expected, Span span) {
span.addEvent("ClearAllState");
}

@Override
void checkEntryHeader(ClearAllStateEntryMessage expected, MessageLite actual)
throws ProtocolException {
Util.assertEntryEquals(expected, actual);
}

@Override
void updateUserStateStoreWithEntry(
ClearAllStateEntryMessage expected, UserStateStore userStateStore) {
userStateStore.clearAll();
}
}

static final class SetStateEntry extends JournalEntry<SetStateEntryMessage> {

static final SetStateEntry INSTANCE = new SetStateEntry();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ public void clear(String name, SyscallCallback<Void> callback) {
syscallsExecutor.execute(() -> syscalls.clear(name, callback));
}

@Override
public void clearAll(SyscallCallback<Void> callback) {
syscallsExecutor.execute(() -> syscalls.clearAll(callback));
}

@Override
public void set(String name, ByteString value, SyscallCallback<Void> callback) {
syscallsExecutor.execute(() -> syscalls.set(name, value, callback));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ public static MessageHeader fromMessage(MessageLite msg) {
return new MessageHeader(MessageType.SetStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.ClearStateEntryMessage) {
return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.ClearAllStateEntryMessage) {
return new MessageHeader(MessageType.ClearAllStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.SleepEntryMessage) {
return new MessageHeader(
MessageType.SleepEntryMessage,
Expand Down
8 changes: 8 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public enum MessageType {
GetStateEntryMessage,
SetStateEntryMessage,
ClearStateEntryMessage,
ClearAllStateEntryMessage,

// Syscalls
SleepEntryMessage,
Expand All @@ -52,6 +53,7 @@ public enum MessageType {
public static final short GET_STATE_ENTRY_MESSAGE_TYPE = 0x0800;
public static final short SET_STATE_ENTRY_MESSAGE_TYPE = 0x0801;
public static final short CLEAR_STATE_ENTRY_MESSAGE_TYPE = 0x0802;
public static final short CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE = 0x0803;
public static final short SLEEP_ENTRY_MESSAGE_TYPE = 0x0C00;
public static final short INVOKE_ENTRY_MESSAGE_TYPE = 0x0C01;
public static final short BACKGROUND_INVOKE_ENTRY_MESSAGE_TYPE = 0x0C02;
Expand Down Expand Up @@ -84,6 +86,8 @@ public Parser<? extends MessageLite> messageParser() {
return Protocol.SetStateEntryMessage.parser();
case ClearStateEntryMessage:
return Protocol.ClearStateEntryMessage.parser();
case ClearAllStateEntryMessage:
return Protocol.ClearAllStateEntryMessage.parser();
case SleepEntryMessage:
return Protocol.SleepEntryMessage.parser();
case InvokeEntryMessage:
Expand Down Expand Up @@ -126,6 +130,8 @@ public short encode() {
return SET_STATE_ENTRY_MESSAGE_TYPE;
case ClearStateEntryMessage:
return CLEAR_STATE_ENTRY_MESSAGE_TYPE;
case ClearAllStateEntryMessage:
return CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE;
case SleepEntryMessage:
return SLEEP_ENTRY_MESSAGE_TYPE;
case InvokeEntryMessage:
Expand Down Expand Up @@ -168,6 +174,8 @@ public static MessageType decode(short value) throws ProtocolException {
return SetStateEntryMessage;
case CLEAR_STATE_ENTRY_MESSAGE_TYPE:
return ClearStateEntryMessage;
case CLEAR_ALL_STATE_ENTRY_MESSAGE_TYPE:
return ClearAllStateEntryMessage;
case SLEEP_ENTRY_MESSAGE_TYPE:
return SleepEntryMessage;
case INVOKE_ENTRY_MESSAGE_TYPE:
Expand Down
13 changes: 13 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,19 @@ public void clear(String name, SyscallCallback<Void> callback) {
callback);
}

@Override
public void clearAll(SyscallCallback<Void> callback) {
wrapAndPropagateExceptions(
() -> {
LOG.trace("clearAll");
this.stateMachine.processJournalEntry(
Protocol.ClearAllStateEntryMessage.newBuilder().build(),
ClearAllStateEntry.INSTANCE,
callback);
},
callback);
}

@Override
public void set(String name, ByteString value, SyscallCallback<Void> callback) {
wrapAndPropagateExceptions(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public ByteString getValue() {
}
}

private final boolean isPartial;
private boolean isPartial;
private final HashMap<ByteString, State> map;

UserStateStore(boolean isPartial, Map<ByteString, ByteString> map) {
Expand All @@ -63,4 +63,9 @@ public void set(ByteString key, ByteString value) {
public void clear(ByteString key) {
this.map.put(key, Empty.INSTANCE);
}

public void clearAll() {
this.map.clear();
this.isPartial = false;
}
}
1 change: 1 addition & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/Util.java
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ static boolean isEntry(MessageLite msg) {
|| msg instanceof Protocol.GetStateEntryMessage
|| msg instanceof Protocol.SetStateEntryMessage
|| msg instanceof Protocol.ClearStateEntryMessage
|| msg instanceof Protocol.ClearAllStateEntryMessage
|| msg instanceof Protocol.SleepEntryMessage
|| msg instanceof Protocol.InvokeEntryMessage
|| msg instanceof Protocol.BackgroundInvokeEntryMessage
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import static org.assertj.core.api.AssertionsForClassTypes.entry;

import com.google.protobuf.MessageLite;
import dev.restate.generated.service.protocol.Protocol.ClearAllStateEntryMessage;
import dev.restate.sdk.core.testservices.GreeterGrpc;
import io.grpc.BindableService;
import java.util.Map;
Expand All @@ -28,7 +29,11 @@ public abstract class EagerStateTestSuite implements TestSuite {

protected abstract BindableService getClearAndGet();

protected abstract BindableService getClearAllAndGet();

private static final Map.Entry<String, String> STATE_FRANCESCO = entry("STATE", "Francesco");
private static final Map.Entry<String, String> ANOTHER_STATE_FRANCESCO =
entry("ANOTHER_STATE", "Francesco");
private static final MessageLite INPUT_TILL = inputMessage(greetingRequest("Till"));
private static final MessageLite GET_STATE_FRANCESCO = getStateMessage("STATE", "Francesco");
private static final MessageLite GET_STATE_FRANCESCO_TILL =
Expand Down Expand Up @@ -109,6 +114,29 @@ public Stream<TestDefinition> definitions() {
getStateEmptyMessage("STATE"),
OUTPUT_FRANCESCO,
END_MESSAGE)
.named("With partial state on the first get"),
testInvocation(this::getClearAllAndGet, GreeterGrpc.getGreetMethod())
.withInput(startMessage(1, STATE_FRANCESCO, ANOTHER_STATE_FRANCESCO), INPUT_TILL)
.expectingOutput(
GET_STATE_FRANCESCO,
ClearAllStateEntryMessage.getDefaultInstance(),
getStateEmptyMessage("STATE"),
getStateEmptyMessage("ANOTHER_STATE"),
OUTPUT_FRANCESCO,
END_MESSAGE)
.named("With state in the state_map"),
testInvocation(this::getClearAllAndGet, GreeterGrpc.getGreetMethod())
.withInput(
startMessage(1).setPartialState(true),
INPUT_TILL,
completionMessage(1, STATE_FRANCESCO.getValue()))
.expectingOutput(
getStateMessage("STATE"),
ClearAllStateEntryMessage.getDefaultInstance(),
getStateEmptyMessage("STATE"),
getStateEmptyMessage("ANOTHER_STATE"),
OUTPUT_FRANCESCO,
END_MESSAGE)
.named("With partial state on the first get"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ public void setOutput(KeyedContext context, SetOutputRequest request) throws Ter
@Override
public void cleanup(KeyedContext context, WorkflowManagerRequest request)
throws TerminalException {
// TODO could use https://github.com/restatedev/restate/issues/224
context.clearAll();
}

private StateKey<DurablePromiseCompletion> durablePromiseKey(String key) {
Expand Down

0 comments on commit dd79547

Please sign in to comment.