Skip to content

Commit

Permalink
Update the entry ack for side effects to use the new mechanism
Browse files Browse the repository at this point in the history
  • Loading branch information
slinkydeveloper committed Dec 18, 2023
1 parent 2ebdbe0 commit 7aabbbb
Show file tree
Hide file tree
Showing 5 changed files with 56 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,10 @@ public void onNext(InvocationFlow.InvocationInput invocationInput) {
// We check the instance rather than the state, because the user code might still be
// replaying, but the network layer is already past it and is receiving completions from the
// runtime.
Protocol.CompletionMessage completionMessage = (Protocol.CompletionMessage) msg;

// If ack, give it to side effect publisher
if (completionMessage.getResultCase()
== Protocol.CompletionMessage.ResultCase.RESULT_NOT_SET) {
this.sideEffectAckStateMachine.tryHandleSideEffectAck(completionMessage.getEntryIndex());
} else {
this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg);
}
this.readyResultStateMachine.offerCompletion((Protocol.CompletionMessage) msg);
} else if (msg instanceof Protocol.EntryAckMessage) {
this.sideEffectAckStateMachine.tryHandleSideEffectAck(
((Protocol.EntryAckMessage) msg).getEntryIndex());
} else {
this.incomingEntriesStateMachine.offer(msg);
}
Expand Down
32 changes: 15 additions & 17 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageHeader.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,14 @@

public class MessageHeader {

private static final short DONE_FLAG = 0x0001;
private static final short REQUIRES_ACK_FLAG = 0x0001;
static final short DONE_FLAG = 0x0001;
static final int REQUIRES_ACK_FLAG = 0x8000;

private final MessageType type;
private final short flags;
private final int flags;
private final int length;

public MessageHeader(MessageType type, short flags, int length) {
public MessageHeader(MessageType type, int flags, int length) {
this.type = type;
this.flags = flags;
this.length = length;
Expand Down Expand Up @@ -57,15 +57,15 @@ public static MessageHeader parse(long encoded) throws ProtocolException {

public static MessageHeader fromMessage(MessageLite msg) {
if (msg instanceof Protocol.SuspensionMessage) {
return new MessageHeader(MessageType.SuspensionMessage, (short) 0, msg.getSerializedSize());
return new MessageHeader(MessageType.SuspensionMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.ErrorMessage) {
return new MessageHeader(MessageType.ErrorMessage, (short) 0, msg.getSerializedSize());
return new MessageHeader(MessageType.ErrorMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.EntryAckMessage) {
return new MessageHeader(MessageType.EntryAckMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.PollInputStreamEntryMessage) {
return new MessageHeader(
MessageType.PollInputStreamEntryMessage, (short) 0, msg.getSerializedSize());
return new MessageHeader(MessageType.PollInputStreamEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.OutputStreamEntryMessage) {
return new MessageHeader(
MessageType.OutputStreamEntryMessage, (short) 0, msg.getSerializedSize());
return new MessageHeader(MessageType.OutputStreamEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.GetStateEntryMessage) {
return new MessageHeader(
MessageType.GetStateEntryMessage,
Expand All @@ -75,11 +75,9 @@ public static MessageHeader fromMessage(MessageLite msg) {
: 0,
msg.getSerializedSize());
} else if (msg instanceof Protocol.SetStateEntryMessage) {
return new MessageHeader(
MessageType.SetStateEntryMessage, (short) 0, msg.getSerializedSize());
return new MessageHeader(MessageType.SetStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.ClearStateEntryMessage) {
return new MessageHeader(
MessageType.ClearStateEntryMessage, (short) 0, msg.getSerializedSize());
return new MessageHeader(MessageType.ClearStateEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.SleepEntryMessage) {
return new MessageHeader(
MessageType.SleepEntryMessage,
Expand All @@ -95,7 +93,7 @@ public static MessageHeader fromMessage(MessageLite msg) {
msg.getSerializedSize());
} else if (msg instanceof Protocol.BackgroundInvokeEntryMessage) {
return new MessageHeader(
MessageType.BackgroundInvokeEntryMessage, (short) 0, msg.getSerializedSize());
MessageType.BackgroundInvokeEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Protocol.AwakeableEntryMessage) {
return new MessageHeader(
MessageType.AwakeableEntryMessage,
Expand All @@ -106,10 +104,10 @@ public static MessageHeader fromMessage(MessageLite msg) {
msg.getSerializedSize());
} else if (msg instanceof Protocol.CompleteAwakeableEntryMessage) {
return new MessageHeader(
MessageType.CompleteAwakeableEntryMessage, (short) 0, msg.getSerializedSize());
MessageType.CompleteAwakeableEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Java.CombinatorAwaitableEntryMessage) {
return new MessageHeader(
MessageType.CombinatorAwaitableEntryMessage, (short) 0, msg.getSerializedSize());
MessageType.CombinatorAwaitableEntryMessage, 0, msg.getSerializedSize());
} else if (msg instanceof Java.SideEffectEntryMessage) {
return new MessageHeader(
MessageType.SideEffectEntryMessage, REQUIRES_ACK_FLAG, msg.getSerializedSize());
Expand Down
8 changes: 8 additions & 0 deletions sdk-core/src/main/java/dev/restate/sdk/core/MessageType.java
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public enum MessageType {
CompletionMessage,
SuspensionMessage,
ErrorMessage,
EntryAckMessage,

// IO
PollInputStreamEntryMessage,
Expand All @@ -43,6 +44,7 @@ public enum MessageType {
public static final short COMPLETION_MESSAGE_TYPE = 0x0001;
public static final short SUSPENSION_MESSAGE_TYPE = 0x0002;
public static final short ERROR_MESSAGE_TYPE = 0x0003;
public static final short ENTRY_ACK_MESSAGE_TYPE = 0x0004;
public static final short POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0400;
public static final short OUTPUT_STREAM_ENTRY_MESSAGE_TYPE = 0x0401;
public static final short GET_STATE_ENTRY_MESSAGE_TYPE = 0x0800;
Expand All @@ -66,6 +68,8 @@ public Parser<? extends MessageLite> messageParser() {
return Protocol.SuspensionMessage.parser();
case ErrorMessage:
return Protocol.ErrorMessage.parser();
case EntryAckMessage:
return Protocol.EntryAckMessage.parser();
case PollInputStreamEntryMessage:
return Protocol.PollInputStreamEntryMessage.parser();
case OutputStreamEntryMessage:
Expand Down Expand Up @@ -104,6 +108,8 @@ public short encode() {
return SUSPENSION_MESSAGE_TYPE;
case ErrorMessage:
return ERROR_MESSAGE_TYPE;
case EntryAckMessage:
return ENTRY_ACK_MESSAGE_TYPE;
case PollInputStreamEntryMessage:
return POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE;
case OutputStreamEntryMessage:
Expand Down Expand Up @@ -142,6 +148,8 @@ public static MessageType decode(short value) throws ProtocolException {
return SuspensionMessage;
case ERROR_MESSAGE_TYPE:
return ErrorMessage;
case ENTRY_ACK_MESSAGE_TYPE:
return EntryAckMessage;
case POLL_INPUT_STREAM_ENTRY_MESSAGE_TYPE:
return PollInputStreamEntryMessage;
case OUTPUT_STREAM_ENTRY_MESSAGE_TYPE:
Expand Down
27 changes: 27 additions & 0 deletions sdk-core/src/test/java/dev/restate/sdk/core/MessageHeaderTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright (c) 2023 - Restate Software, Inc., Restate GmbH
//
// This file is part of the Restate Java SDK,
// which is released under the MIT license.
//
// You can find a copy of the license in file LICENSE in the root
// directory of this repository or package, or at
// https://github.com/restatedev/sdk-java/blob/main/LICENSE
package dev.restate.sdk.core;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;

public class MessageHeaderTest {

@Test
void requiresAckFlag() {
assertThat(
new MessageHeader(
MessageType.InvokeEntryMessage,
MessageHeader.DONE_FLAG | MessageHeader.REQUIRES_ACK_FLAG,
2)
.encode())
.isEqualTo(0x0C01_8001_0000_0002L);
}
}
4 changes: 2 additions & 2 deletions sdk-core/src/test/java/dev/restate/sdk/core/ProtoUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ public static Protocol.CompletionMessage completionMessage(int index, Throwable
.build();
}

public static Protocol.CompletionMessage ackMessage(int index) {
return Protocol.CompletionMessage.newBuilder().setEntryIndex(index).build();
public static Protocol.EntryAckMessage ackMessage(int index) {
return Protocol.EntryAckMessage.newBuilder().setEntryIndex(index).build();
}

public static Protocol.SuspensionMessage suspensionMessage(Integer... indexes) {
Expand Down

0 comments on commit 7aabbbb

Please sign in to comment.