Skip to content

Commit

Permalink
Stop assuming input and output of invocations implement MessageLite (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper authored Feb 8, 2024
1 parent 530035e commit f870d57
Show file tree
Hide file tree
Showing 8 changed files with 75 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
package dev.restate.sdk.common.syscalls;

import dev.restate.sdk.common.TerminalException;
import java.util.function.Function;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -47,6 +48,29 @@ private Result() {}
@Nullable
public abstract TerminalException getFailure();

// --- Helper methods

/**
* Map this result success value. If the mapper throws an exception, this exception will be
* converted to {@link TerminalException} and return a new failed {@link Result}.
*/
public <U> Result<U> mapSuccess(Function<T, U> mapper) {
if (this.isSuccess()) {
try {
return Result.success(mapper.apply(this.getValue()));
} catch (TerminalException e) {
return Result.failure(e);
} catch (Exception e) {
return Result.failure(
new TerminalException(TerminalException.Code.UNKNOWN, e.getMessage()));
}
}
//noinspection unchecked
return (Result<U>) this;
}

// --- Factory methods

@SuppressWarnings("unchecked")
public static <T> Result<T> empty() {
return (Result<T>) Empty.INSTANCE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package dev.restate.sdk.common.syscalls;

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import dev.restate.sdk.common.InvocationId;
import dev.restate.sdk.common.TerminalException;
import io.grpc.Context;
Expand All @@ -18,7 +17,6 @@
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Function;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -53,10 +51,9 @@ static Syscalls current() {
// Note: These are not supposed to be exposed to RestateContext, but they should be used through
// gRPC APIs.

<T extends MessageLite> void pollInput(
Function<ByteString, T> mapper, SyscallCallback<Deferred<T>> callback);
void pollInput(SyscallCallback<Deferred<ByteString>> callback);

<T extends MessageLite> void writeOutput(T value, SyscallCallback<Void> callback);
void writeOutput(ByteString value, SyscallCallback<Void> callback);

void writeOutput(TerminalException exception, SyscallCallback<Void> callback);

Expand Down
18 changes: 8 additions & 10 deletions sdk-core/src/main/java/dev/restate/sdk/core/Entries.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,12 @@ void updateUserStateStorageWithCompletion(
E expected, CompletionMessage actual, UserStateStore userStateStore) {}
}

static final class PollInputEntry<R extends MessageLite>
extends CompletableJournalEntry<PollInputStreamEntryMessage, R> {
static final class PollInputEntry
extends CompletableJournalEntry<PollInputStreamEntryMessage, ByteString> {

private final Function<ByteString, Result<R>> valueParser;
static final PollInputEntry INSTANCE = new PollInputEntry();

PollInputEntry(Function<ByteString, Result<R>> valueParser) {
this.valueParser = valueParser;
}
private PollInputEntry() {}

@Override
public void trace(PollInputStreamEntryMessage expected, Span span) {
Expand All @@ -69,9 +67,9 @@ public boolean hasResult(PollInputStreamEntryMessage actual) {
}

@Override
public Result<R> parseEntryResult(PollInputStreamEntryMessage actual) {
public Result<ByteString> parseEntryResult(PollInputStreamEntryMessage actual) {
if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.VALUE) {
return valueParser.apply(actual.getValue());
return Result.success(actual.getValue());
} else if (actual.getResultCase() == PollInputStreamEntryMessage.ResultCase.FAILURE) {
return Result.failure(Util.toRestateException(actual.getFailure()));
} else {
Expand All @@ -80,9 +78,9 @@ public Result<R> parseEntryResult(PollInputStreamEntryMessage actual) {
}

@Override
public Result<R> parseCompletionResult(CompletionMessage actual) {
public Result<ByteString> parseCompletionResult(CompletionMessage actual) {
if (actual.getResultCase() == CompletionMessage.ResultCase.VALUE) {
return valueParser.apply(actual.getValue());
return Result.success(actual.getValue());
} else if (actual.getResultCase() == CompletionMessage.ResultCase.FAILURE) {
return Result.failure(Util.toRestateException(actual.getFailure()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package dev.restate.sdk.core;

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import dev.restate.sdk.common.InvocationId;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.common.syscalls.Deferred;
Expand All @@ -20,7 +19,6 @@
import java.time.Duration;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Function;

class ExecutorSwitchingSyscalls implements SyscallsInternal {

Expand All @@ -33,13 +31,12 @@ class ExecutorSwitchingSyscalls implements SyscallsInternal {
}

@Override
public <T extends MessageLite> void pollInput(
Function<ByteString, T> mapper, SyscallCallback<Deferred<T>> callback) {
syscallsExecutor.execute(() -> syscalls.pollInput(mapper, callback));
public void pollInput(SyscallCallback<Deferred<ByteString>> callback) {
syscallsExecutor.execute(() -> syscalls.pollInput(callback));
}

@Override
public <T extends MessageLite> void writeOutput(T value, SyscallCallback<Void> callback) {
public void writeOutput(ByteString value, SyscallCallback<Void> callback) {
syscallsExecutor.execute(() -> syscalls.writeOutput(value, callback));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,36 +8,37 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

import com.google.protobuf.MessageLite;
import com.google.protobuf.ByteString;
import dev.restate.sdk.common.InvocationId;
import dev.restate.sdk.common.TerminalException;
import dev.restate.sdk.common.syscalls.SyscallCallback;
import dev.restate.sdk.common.syscalls.Syscalls;
import io.grpc.*;
import java.io.IOException;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import javax.annotation.Nullable;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

class GrpcUnaryRpcHandler implements RpcHandler {
class GrpcUnaryRpcHandler<Req, Res> implements RpcHandler {

private static final Logger LOG = LogManager.getLogger(GrpcUnaryRpcHandler.class);

private final SyscallsInternal syscalls;
private final RestateServerCallListener<MessageLite> restateListener;
private final RestateServerCallListener<Req> restateListener;
private final CompletableFuture<Void> serverCallReady;
private final MethodDescriptor<MessageLite, MessageLite> methodDescriptor;
private final MethodDescriptor<Req, Res> methodDescriptor;

GrpcUnaryRpcHandler(
ServerMethodDefinition<MessageLite, MessageLite> method,
ServerMethodDefinition<Req, Res> method,
SyscallsInternal syscalls,
@Nullable Executor userCodeExecutor) {
this.syscalls = syscalls;
this.methodDescriptor = method.getMethodDescriptor();
this.serverCallReady = new CompletableFuture<>();
RestateServerCall serverCall =
new RestateServerCall(this.methodDescriptor, this.syscalls, this.serverCallReady);
RestateServerCall<Req, Res> serverCall =
new RestateServerCall<>(method.getMethodDescriptor(), this.syscalls, this.serverCallReady);

// This gRPC context will be propagated to the user thread.
// Note: from now on we cannot modify this context anymore!
Expand All @@ -47,13 +48,13 @@ class GrpcUnaryRpcHandler implements RpcHandler {
.withValue(Syscalls.SYSCALLS_KEY, this.syscalls);

// Create the listener
RestateServerCallListener<MessageLite> listener =
RestateServerCallListener<Req> listener =
new GrpcServerCallListenerAdaptor<>(
context, serverCall, new Metadata(), method.getServerCallHandler());

// Wrap in the executor switcher, if needed
if (userCodeExecutor != null) {
listener = new ExecutorSwitchingServerCallListener(listener, userCodeExecutor);
listener = new ExecutorSwitchingServerCallListener<>(listener, userCodeExecutor);
}

this.restateListener = listener;
Expand All @@ -69,7 +70,7 @@ public void start() {
SyscallCallback.of(
pollInputReadyResult -> {
if (pollInputReadyResult.isSuccess()) {
final MessageLite message = pollInputReadyResult.getValue();
final Req message = pollInputReadyResult.getValue();
LOG.trace("Read input message:\n{}", message);

// In theory, we never need this, as once we reach this point of the code the server
Expand Down Expand Up @@ -198,20 +199,20 @@ private void closeWithException(Throwable e) {
}
}

private static class ExecutorSwitchingServerCallListener
implements RestateServerCallListener<MessageLite> {
private static class ExecutorSwitchingServerCallListener<Req>
implements RestateServerCallListener<Req> {

private final RestateServerCallListener<MessageLite> listener;
private final RestateServerCallListener<Req> listener;
private final Executor userExecutor;

private ExecutorSwitchingServerCallListener(
RestateServerCallListener<MessageLite> listener, Executor userExecutor) {
RestateServerCallListener<Req> listener, Executor userExecutor) {
this.listener = listener;
this.userExecutor = userExecutor;
}

@Override
public void invoke(MessageLite message) {
public void invoke(Req message) {
userExecutor.execute(() -> listener.invoke(message));
}

Expand Down Expand Up @@ -254,17 +255,17 @@ public void ready() {
* <li>Trampolining back to state machine executor is provided by the syscalls wrapper.
* </ul>
*/
static class RestateServerCall extends ServerCall<MessageLite, MessageLite> {
static class RestateServerCall<Req, Res> extends ServerCall<Req, Res> {

private final MethodDescriptor<MessageLite, MessageLite> methodDescriptor;
private final MethodDescriptor<Req, Res> methodDescriptor;
private final SyscallsInternal syscalls;

// This variable don't need to be volatile because it's accessed only by #request()
private int inputPollRequests = 0;
private final CompletableFuture<Void> serverCallReady;

RestateServerCall(
MethodDescriptor<MessageLite, MessageLite> methodDescriptor,
MethodDescriptor<Req, Res> methodDescriptor,
SyscallsInternal syscalls,
CompletableFuture<Void> serverCallReady) {
this.methodDescriptor = methodDescriptor;
Expand Down Expand Up @@ -308,9 +309,17 @@ public void sendHeaders(Metadata headers) {
}

@Override
public void sendMessage(MessageLite message) {
public void sendMessage(Res message) {
ByteString output;
try {
output = ByteString.readFrom(methodDescriptor.streamResponse(message));
} catch (IOException e) {
syscalls.fail(e);
return;
}

syscalls.writeOutput(
message,
output,
SyscallCallback.ofVoid(
() -> LOG.trace("Wrote output message:\n{}", message), syscalls::fail));
}
Expand Down Expand Up @@ -346,7 +355,7 @@ public boolean isCancelled() {
}

@Override
public MethodDescriptor<MessageLite, MessageLite> getMethodDescriptor() {
public MethodDescriptor<Req, Res> getMethodDescriptor() {
return methodDescriptor;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

import com.google.protobuf.MessageLite;
import dev.restate.generated.service.discovery.Discovery;
import dev.restate.sdk.common.ServiceAdapter;
import dev.restate.sdk.common.ServicesBundle;
Expand Down Expand Up @@ -64,9 +63,7 @@ public InvocationHandler resolve(
throw ProtocolException.methodNotFound(serviceName, methodName);
}
String fullyQualifiedServiceMethod = serviceName + "/" + methodName;
ServerMethodDefinition<MessageLite, MessageLite> method =
(ServerMethodDefinition<MessageLite, MessageLite>)
svc.getMethod(fullyQualifiedServiceMethod);
ServerMethodDefinition<?, ?> method = svc.getMethod(fullyQualifiedServiceMethod);
if (method == null) {
throw ProtocolException.methodNotFound(serviceName, methodName);
}
Expand Down
13 changes: 4 additions & 9 deletions sdk-core/src/main/java/dev/restate/sdk/core/SyscallsImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package dev.restate.sdk.core;

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import com.google.rpc.Code;
import dev.restate.generated.sdk.java.Java;
import dev.restate.generated.service.protocol.Protocol;
Expand Down Expand Up @@ -48,27 +47,23 @@ public InvocationId invocationId() {
}

@Override
public <T extends MessageLite> void pollInput(
Function<ByteString, T> mapper, SyscallCallback<Deferred<T>> callback) {
public void pollInput(SyscallCallback<Deferred<ByteString>> callback) {
wrapAndPropagateExceptions(
() -> {
LOG.trace("pollInput");
this.stateMachine.processCompletableJournalEntry(
PollInputStreamEntryMessage.getDefaultInstance(),
new PollInputEntry<>(protoDeserializer(mapper)),
callback);
PollInputStreamEntryMessage.getDefaultInstance(), PollInputEntry.INSTANCE, callback);
},
callback);
}

@Override
public <T extends MessageLite> void writeOutput(T value, SyscallCallback<Void> callback) {
public void writeOutput(ByteString value, SyscallCallback<Void> callback) {
wrapAndPropagateExceptions(
() -> {
LOG.trace("writeOutput success");
this.writeOutput(
Protocol.OutputStreamEntryMessage.newBuilder().setValue(value.toByteString()).build(),
callback);
Protocol.OutputStreamEntryMessage.newBuilder().setValue(value).build(), callback);
},
callback);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
package dev.restate.sdk.core;

import com.google.protobuf.ByteString;
import com.google.protobuf.MessageLite;
import dev.restate.sdk.common.syscalls.Deferred;
import dev.restate.sdk.common.syscalls.Result;
import dev.restate.sdk.common.syscalls.SyscallCallback;
Expand All @@ -35,16 +34,16 @@ default Deferred<Void> createAllDeferred(List<Deferred<?>> children) {

// -- Helper for pollInput

default <T extends MessageLite> void pollInputAndResolve(
default <T> void pollInputAndResolve(
Function<ByteString, T> mapper, SyscallCallback<Result<T>> callback) {
this.pollInput(
mapper,
SyscallCallback.of(
deferredValue ->
this.resolveDeferred(
deferredValue,
SyscallCallback.ofVoid(
() -> callback.onSuccess(deferredValue.toResult()), callback::onCancel)),
() -> callback.onSuccess(deferredValue.toResult().mapSuccess(mapper)),
callback::onCancel)),
callback::onCancel));
}

Expand Down

0 comments on commit f870d57

Please sign in to comment.