Skip to content

Commit

Permalink
Use the StartMessage.partial_state flag, Remove partial_state flag pa…
Browse files Browse the repository at this point in the history
…rsing
  • Loading branch information
slinkydeveloper committed Sep 6, 2023
1 parent 2a01b5e commit 0349058
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,10 @@ public void onSubscribe(Flow.Subscription subscription) {

@Override
public void onNext(InvocationFlow.InvocationInput invocationInput) {
MessageHeader header = invocationInput.header();
MessageLite msg = invocationInput.message();
LOG.trace("Received input message {} {}", msg.getClass(), msg);
if (this.state == State.WAITING_START) {
this.onStart(header, msg);
this.onStart(msg);
} else if (msg instanceof Protocol.CompletionMessage) {
// We check the instance rather than the state, because the user code might still be
// replaying, but the network layer is already past it and is receiving completions from the
Expand Down Expand Up @@ -154,7 +153,7 @@ void start(Consumer<InvocationId> afterStartCallback) {
this.inputSubscription.request(1);
}

void onStart(MessageHeader header, MessageLite msg) {
void onStart(MessageLite msg) {
if (!(msg instanceof Protocol.StartMessage)) {
this.fail(ProtocolException.unexpectedMessage(Protocol.StartMessage.class, msg));
return;
Expand All @@ -169,7 +168,7 @@ void onStart(MessageHeader header, MessageLite msg) {
// Set up the state cache
this.localStateStorage =
new LocalStateStorage(
header.hasFlag(MessageHeader.PARTIAL_STATE_FLAG),
startMessage.getPartialState(),
startMessage.getStateMapList().stream()
.collect(
Collectors.toMap(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

public class MessageHeader {

public static final short PARTIAL_STATE_FLAG = 0x0400;
private static final short DONE_FLAG = 0x0001;
private static final short REQUIRES_ACK_FLAG = 0x0001;

Expand Down Expand Up @@ -36,10 +35,6 @@ public long encode() {
return res;
}

public boolean hasFlag(short flag) {
return (this.flags & flag) > 0;
}

public MessageHeader copyWithFlags(short flag) {
return new MessageHeader(type, flag, length);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@ public void greet(GreetingRequest request, StreamObserver<GreetingResponse> resp
}
}

private static final short PARTIAL_STATE = PARTIAL_STATE_FLAG;
private static final short COMPLETE_STATE = 0;
private static final Map.Entry<String, String> STATE_FRANCESCO = entry("STATE", "Francesco");
private static final MessageLite INPUT_TILL = inputMessage(greetingRequest("Till"));
private static final MessageLite GET_STATE_FRANCESCO = getStateMessage("STATE", "Francesco");
Expand All @@ -94,38 +92,33 @@ public void greet(GreetingRequest request, StreamObserver<GreetingResponse> resp
Stream<TestDefinition> definitions() {
return Stream.of(
testInvocation(new GetEmpty(), GreeterGrpc.getGreetMethod())
.withInput(COMPLETE_STATE, startMessage(1))
.withInput(INPUT_TILL)
.withInput(startMessage(1).setPartialState(false), INPUT_TILL)
.usingAllThreadingModels()
.expectingOutput(getStateEmptyMessage("STATE"), outputMessage(greetingResponse("true")))
.named("With complete state"),
testInvocation(new GetEmpty(), GreeterGrpc.getGreetMethod())
.withInput(PARTIAL_STATE, startMessage(1))
.withInput(INPUT_TILL)
.withInput(startMessage(1).setPartialState(true), INPUT_TILL)
.usingAllThreadingModels()
.expectingOutput(getStateMessage("STATE"), suspensionMessage(1))
.named("With partial state"),
testInvocation(new GetEmpty(), GreeterGrpc.getGreetMethod())
.withInput(PARTIAL_STATE, startMessage(2))
.withInput(INPUT_TILL, getStateEmptyMessage("STATE"))
.withInput(
startMessage(2).setPartialState(true), INPUT_TILL, getStateEmptyMessage("STATE"))
.usingAllThreadingModels()
.expectingOutput(outputMessage(greetingResponse("true")))
.named("Resume with partial state"),
testInvocation(new Get(), GreeterGrpc.getGreetMethod())
.withInput(COMPLETE_STATE, startMessage(1, STATE_FRANCESCO))
.withInput(INPUT_TILL)
.withInput(startMessage(1, STATE_FRANCESCO).setPartialState(false), INPUT_TILL)
.usingAllThreadingModels()
.expectingOutput(GET_STATE_FRANCESCO, OUTPUT_FRANCESCO)
.named("With complete state"),
testInvocation(new Get(), GreeterGrpc.getGreetMethod())
.withInput(PARTIAL_STATE, startMessage(1, STATE_FRANCESCO))
.withInput(INPUT_TILL)
.withInput(startMessage(1, STATE_FRANCESCO).setPartialState(true), INPUT_TILL)
.usingAllThreadingModels()
.expectingOutput(GET_STATE_FRANCESCO, OUTPUT_FRANCESCO)
.named("With partial state"),
testInvocation(new Get(), GreeterGrpc.getGreetMethod())
.withInput(PARTIAL_STATE, startMessage(1))
.withInput(INPUT_TILL)
.withInput(startMessage(1).setPartialState(true), INPUT_TILL)
.usingAllThreadingModels()
.expectingOutput(getStateMessage("STATE"), suspensionMessage(1))
.named("With partial state without the state entry"),
Expand All @@ -139,8 +132,10 @@ Stream<TestDefinition> definitions() {
OUTPUT_FRANCESCO_TILL)
.named("With state in the state_map"),
testInvocation(new GetAppendAndGet(), GreeterGrpc.getGreetMethod())
.withInput(PARTIAL_STATE, startMessage(1))
.withInput(INPUT_TILL, completionMessage(1, "Francesco"))
.withInput(
startMessage(1).setPartialState(true),
INPUT_TILL,
completionMessage(1, "Francesco"))
.usingAllThreadingModels()
.expectingOutput(
getStateMessage("STATE"),
Expand All @@ -158,8 +153,10 @@ Stream<TestDefinition> definitions() {
OUTPUT_FRANCESCO)
.named("With state in the state_map"),
testInvocation(new GetClearAndGet(), GreeterGrpc.getGreetMethod())
.withInput(PARTIAL_STATE, startMessage(1))
.withInput(INPUT_TILL, completionMessage(1, "Francesco"))
.withInput(
startMessage(1).setPartialState(true),
INPUT_TILL,
completionMessage(1, "Francesco"))
.usingAllThreadingModels()
.expectingOutput(
getStateMessage("STATE"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ public class ProtoUtils {
*/
public static MessageHeader headerFromMessage(MessageLite msg) {
if (msg instanceof Protocol.StartMessage) {
return new MessageHeader(
MessageType.StartMessage, MessageHeader.PARTIAL_STATE_FLAG, msg.getSerializedSize());
return new MessageHeader(MessageType.StartMessage, (short) 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.CompletionMessage) {
return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize());
}
Expand All @@ -38,7 +37,8 @@ public static Protocol.StartMessage.Builder startMessage(int entries) {
return Protocol.StartMessage.newBuilder()
.setId(ByteString.copyFromUtf8("abc"))
.setDebugId("abc")
.setKnownEntries(entries);
.setKnownEntries(entries)
.setPartialState(true);
}

@SafeVarargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ public void testInvoke(String serviceName) throws IOException {
.setDebugId("123")
.setId(ByteString.copyFromUtf8("123"))
.setKnownEntries(1)
.setPartialState(true)
.build(),
Protocol.PollInputStreamEntryMessage.newBuilder()
.setValue(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@ public void cancel() {}

static MessageHeader headerFromMessage(MessageLite msg) {
if (msg instanceof Protocol.StartMessage) {
return new MessageHeader(
MessageType.StartMessage, MessageHeader.PARTIAL_STATE_FLAG, msg.getSerializedSize());
return new MessageHeader(MessageType.StartMessage, (short) 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.CompletionMessage) {
return new MessageHeader(MessageType.CompletionMessage, (short) 0, msg.getSerializedSize());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ private void handle(
.setDebugId(invocationId)
.setId(ByteString.copyFromUtf8(invocationId))
.setKnownEntries(1)
.setPartialState(true)
.build();
List<MessageLite> inputMessages = List.of(startMessage, inputMessage);

Expand Down

0 comments on commit 0349058

Please sign in to comment.