Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 34 additions & 19 deletions packages/restate-sdk/src/context_impl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ import {
} from "./types/errors.js";
import type { Client, SendClient } from "./types/rpc.js";
import {
defaultSerde,
HandlerKind,
makeRpcCallProxy,
makeRpcSendProxy,
Expand Down Expand Up @@ -96,6 +95,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
private readonly outputPump: OutputPump;
private readonly runClosuresTracker: RunClosuresTracker;
readonly promisesExecutor: PromisesExecutor;
readonly defaultSerde: Serde<any>;

constructor(
readonly coreVm: vm.WasmVM,
Expand All @@ -108,6 +108,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
inputReader: ReadableStreamDefaultReader<Uint8Array>,
outputWriter: WritableStreamDefaultWriter<Uint8Array>,
readonly journalValueCodec: JournalValueCodec,
defaultSerde?: Serde<any>,
private readonly asTerminalError?: (error: any) => TerminalError | undefined
) {
this.rand = new RandImpl(input.random_seed, () => {
Expand All @@ -131,6 +132,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
this.runClosuresTracker,
this.promiseExecutorErrorCallback.bind(this)
);
this.defaultSerde = defaultSerde ?? serde.json;
}

cancel(invocationId: InvocationId): void {
Expand All @@ -146,7 +148,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
WasmCommandType.AttachInvocation,
() => {},
(vm) => vm.sys_attach_invocation(invocationId),
SuccessWithSerde(serde ?? defaultSerde(), this.journalValueCodec),
SuccessWithSerde(serde ?? this.defaultSerde, this.journalValueCodec),
Failure
);
}
Expand All @@ -173,7 +175,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
() => {},
(vm) => vm.sys_get_state(name),
VoidAsNull,
SuccessWithSerde(serde ?? defaultSerde(), this.journalValueCodec)
SuccessWithSerde(serde ?? this.defaultSerde, this.journalValueCodec)
);
}

Expand All @@ -191,7 +193,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
WasmCommandType.SetState,
() =>
this.journalValueCodec.encode(
(serde ?? defaultSerde()).serialize(value)
(serde ?? this.defaultSerde).serialize(value)
),
(vm, bytes) => vm.sys_set_state(name, bytes)
);
Expand Down Expand Up @@ -344,21 +346,36 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
}

serviceClient<D>({ name }: ServiceDefinitionFrom<D>): Client<Service<D>> {
return makeRpcCallProxy((call) => this.genericCall(call), name);
return makeRpcCallProxy(
(call) => this.genericCall(call),
this.defaultSerde,

name
);
}

objectClient<D>(
{ name }: VirtualObjectDefinitionFrom<D>,
key: string
): Client<VirtualObject<D>> {
return makeRpcCallProxy((call) => this.genericCall(call), name, key);
return makeRpcCallProxy(
(call) => this.genericCall(call),
this.defaultSerde,
name,
key
);
}

workflowClient<D>(
{ name }: WorkflowDefinitionFrom<D>,
key: string
): Client<Workflow<D>> {
return makeRpcCallProxy((call) => this.genericCall(call), name, key);
return makeRpcCallProxy(
(call) => this.genericCall(call),
this.defaultSerde,
name,
key
);
}

public serviceSendClient<D>(
Expand All @@ -367,6 +384,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
): SendClient<Service<D>> {
return makeRpcSendProxy(
(send) => this.genericSend(send),
this.defaultSerde,
name,
undefined,
opts?.delay
Expand All @@ -380,6 +398,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
): SendClient<VirtualObject<D>> {
return makeRpcSendProxy(
(send) => this.genericSend(send),
this.defaultSerde,
name,
key,
opts?.delay
Expand All @@ -393,6 +412,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
): SendClient<Workflow<D>> {
return makeRpcSendProxy(
(send) => this.genericSend(send),
this.defaultSerde,
name,
key,
opts?.delay
Expand All @@ -412,7 +432,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
nameOrAction,
actionSecondParameter
);
const serde = options?.serde ?? defaultSerde();
const serde = options?.serde ?? this.defaultSerde ?? this.defaultSerde;

// Prepare the handle
let handle: number;
Expand Down Expand Up @@ -586,7 +606,7 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
awakeable.handle,
completeSignalPromiseUsing(
VoidAsUndefined,
SuccessWithSerde(serde, this.journalValueCodec),
SuccessWithSerde(serde ?? this.defaultSerde, this.journalValueCodec),
Failure
)
),
Expand All @@ -606,8 +626,8 @@ export class ContextImpl implements ObjectContext, WorkflowContext {
} else {
value =
payload !== undefined
? defaultSerde().serialize(payload)
: defaultSerde().serialize(null);
? this.defaultSerde.serialize(payload)
: this.defaultSerde.serialize(null);
}
return this.journalValueCodec.encode(value);
},
Expand Down Expand Up @@ -793,7 +813,7 @@ class DurablePromiseImpl<T> implements DurablePromise<T> {
private readonly name: string,
serde?: Serde<T>
) {
this.serde = serde ?? defaultSerde();
this.serde = serde ?? (this.ctx.defaultSerde as unknown as Serde<T>);
}

then<TResult1 = T, TResult2 = never>(
Expand Down Expand Up @@ -981,7 +1001,7 @@ const VoidAsUndefined: Completer = (value, prom) => {
};

function SuccessWithSerde<T>(
serde?: Serde<T>,
serde: Serde<T>,
journalCodec?: JournalValueCodec,
transform?: <U>(success: T) => U
): Completer {
Expand All @@ -995,12 +1015,7 @@ function SuccessWithSerde<T>(
} else {
buffer = value.Success;
}
let val: T;
if (serde) {
val = serde.deserialize(buffer);
} else {
val = defaultSerde<T>().deserialize(buffer);
}
let val = serde.deserialize(buffer);
if (transform) {
val = transform(val);
}
Expand Down
66 changes: 42 additions & 24 deletions packages/restate-sdk/src/endpoint/components.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ import type {
WorkflowOptions,
} from "../types/rpc.js";
import { HandlerKind } from "../types/rpc.js";
import { millisOrDurationToMillis } from "@restatedev/restate-sdk-core";
import type { Serde } from "@restatedev/restate-sdk-core";
import { millisOrDurationToMillis, serde } from "@restatedev/restate-sdk-core";

//
// Interfaces
Expand All @@ -44,17 +45,22 @@ export interface ComponentHandler {
// Service
//

function handlerInputDiscovery(handler: HandlerWrapper): d.InputPayload {
function handlerInputDiscovery(
handler: HandlerWrapper,
defaultSerde: Serde<any>
): d.InputPayload {
const serde = handler.inputSerde ?? defaultSerde;

let contentType = undefined;
let jsonSchema = undefined;

if (handler.inputSerde.jsonSchema) {
jsonSchema = handler.inputSerde.jsonSchema;
contentType = handler.accept ?? handler.inputSerde.contentType;
if (serde.jsonSchema) {
jsonSchema = serde.jsonSchema;
contentType = handler.accept ?? serde.contentType;
} else if (handler.accept) {
contentType = handler.accept;
} else if (handler.inputSerde.contentType) {
contentType = handler.inputSerde.contentType;
} else if (serde.contentType) {
contentType = serde.contentType;
} else {
// no input information
return {};
Expand All @@ -67,20 +73,20 @@ function handlerInputDiscovery(handler: HandlerWrapper): d.InputPayload {
};
}

function handlerOutputDiscovery(handler: HandlerWrapper): d.OutputPayload {
function handlerOutputDiscovery(
handler: HandlerWrapper,
defaultSerde: Serde<any>
): d.OutputPayload {
const serde = handler.outputSerde ?? defaultSerde;

let contentType = undefined;
let jsonSchema = undefined;

if (handler.outputSerde.jsonSchema) {
jsonSchema = handler.outputSerde.jsonSchema;
contentType =
handler.contentType ??
handler.outputSerde.contentType ??
"application/json";
} else if (handler.contentType) {
contentType = handler.contentType;
} else if (handler.outputSerde.contentType) {
contentType = handler.outputSerde.contentType;
if (serde.jsonSchema) {
jsonSchema = serde.jsonSchema;
contentType = serde.contentType ?? "application/json";
} else if (serde.contentType) {
contentType = serde.contentType;
} else {
// no input information
return { setContentTypeIfEmpty: false };
Expand Down Expand Up @@ -116,7 +122,10 @@ export class ServiceComponent implements Component {
([name, handler]) => {
return {
name,
...commonHandlerOptions(handler.handlerWrapper),
...commonHandlerOptions(
handler.handlerWrapper,
this.options?.defaultSerde ?? serde.json
),
} satisfies d.Handler;
}
);
Expand Down Expand Up @@ -188,7 +197,10 @@ export class VirtualObjectComponent implements Component {
return {
name,
ty: handler.kind() === HandlerKind.EXCLUSIVE ? "EXCLUSIVE" : "SHARED",
...commonHandlerOptions(handler.handlerWrapper),
...commonHandlerOptions(
handler.handlerWrapper,
this.options?.defaultSerde ?? serde.json
),
} satisfies d.Handler;
}
);
Expand Down Expand Up @@ -263,7 +275,10 @@ export class WorkflowComponent implements Component {
this.options?.workflowRetention !== undefined
? millisOrDurationToMillis(this.options?.workflowRetention)
: undefined,
...commonHandlerOptions(handler.handlerWrapper),
...commonHandlerOptions(
handler.handlerWrapper,
this.options?.defaultSerde ?? serde.json
),
} satisfies d.Handler;
}
);
Expand Down Expand Up @@ -382,10 +397,13 @@ function commonServiceOptions(
};
}

function commonHandlerOptions(wrapper: HandlerWrapper) {
function commonHandlerOptions(
wrapper: HandlerWrapper,
defaultSerde: Serde<any>
) {
return {
input: handlerInputDiscovery(wrapper),
output: handlerOutputDiscovery(wrapper),
input: handlerInputDiscovery(wrapper, defaultSerde),
output: handlerOutputDiscovery(wrapper, defaultSerde),
journalRetention:
wrapper.journalRetention !== undefined
? millisOrDurationToMillis(wrapper.journalRetention)
Expand Down
1 change: 1 addition & 0 deletions packages/restate-sdk/src/endpoint/handlers/generic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ export class GenericHandler implements RestateHandler {
inputReader,
outputWriter,
journalValueCodec,
service.options?.defaultSerde,
service.options?.asTerminalError
);

Expand Down
Loading
Loading