Skip to content

Commit

Permalink
Fixed consuming realtime runs w/streams after the run is already fini…
Browse files Browse the repository at this point in the history
…shed
  • Loading branch information
ericallam committed Dec 13, 2024
1 parent 3379fc4 commit 7014057
Show file tree
Hide file tree
Showing 8 changed files with 416 additions and 46 deletions.
19 changes: 11 additions & 8 deletions packages/core/src/v3/apiClient/runStream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -116,11 +116,14 @@ export function runShapeStream<TRunTypes extends AnyRunTypes>(
{ once: true }
);

const runStreamInstance = zodShapeStream(SubscribeRunRawShape, url, {
...options,
signal: abortController.signal,
});

const $options: RunSubscriptionOptions = {
runShapeStream: zodShapeStream(SubscribeRunRawShape, url, {
...options,
signal: abortController.signal,
}),
runShapeStream: runStreamInstance.stream,
stopRunShapeStream: runStreamInstance.stop,
streamFactory: new VersionedStreamSubscriptionFactory(version1, version2),
abortController,
...options,
Expand Down Expand Up @@ -215,7 +218,7 @@ export class ElectricStreamSubscription implements StreamSubscription {

async subscribe(): Promise<ReadableStream<unknown>> {
return zodShapeStream(SubscribeRealtimeStreamChunkRawShape, this.url, this.options)
.pipeThrough(
.stream.pipeThrough(
new TransformStream({
transform(chunk, controller) {
controller.enqueue(chunk.value);
Expand Down Expand Up @@ -298,12 +301,12 @@ export interface RunShapeProvider {

export type RunSubscriptionOptions = RunShapeStreamOptions & {
runShapeStream: ReadableStream<SubscribeRunRawShape>;
stopRunShapeStream: () => void;
streamFactory: StreamSubscriptionFactory;
abortController: AbortController;
};

export class RunSubscription<TRunTypes extends AnyRunTypes> {
private unsubscribeShape?: () => void;
private stream: AsyncIterableStream<RunShape<TRunTypes>>;
private packetCache = new Map<string, any>();
private _closeOnComplete: boolean;
Expand All @@ -330,7 +333,7 @@ export class RunSubscription<TRunTypes extends AnyRunTypes> {
) {
console.log("Closing stream because run is complete");

this.options.abortController.abort();
this.options.stopRunShapeStream();
}
},
},
Expand All @@ -342,7 +345,7 @@ export class RunSubscription<TRunTypes extends AnyRunTypes> {
if (!this.options.abortController.signal.aborted) {
this.options.abortController.abort();
}
this.unsubscribeShape?.();
this.options.stopRunShapeStream();
}

[Symbol.asyncIterator](): AsyncIterator<RunShape<TRunTypes>> {
Expand Down
91 changes: 61 additions & 30 deletions packages/core/src/v3/apiClient/stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,40 @@ export type ZodShapeStreamOptions = {
signal?: AbortSignal;
};

export type ZodShapeStreamInstance<TShapeSchema extends z.ZodTypeAny> = {
stream: AsyncIterableStream<z.output<TShapeSchema>>;
stop: () => void;
};

export function zodShapeStream<TShapeSchema extends z.ZodTypeAny>(
schema: TShapeSchema,
url: string,
options?: ZodShapeStreamOptions
) {
const stream = new ShapeStream<z.input<TShapeSchema>>({
): ZodShapeStreamInstance<TShapeSchema> {
const abortController = new AbortController();

options?.signal?.addEventListener(
"abort",
() => {
abortController.abort();
},
{ once: true }
);

const shapeStream = new ShapeStream({
url,
headers: {
...options?.headers,
"x-trigger-electric-version": "1.0.0-beta.1",
},
fetchClient: options?.fetchClient,
signal: options?.signal,
signal: abortController.signal,
});

const readableShape = new ReadableShapeStream(stream);
const readableShape = new ReadableShapeStream(shapeStream);

return readableShape.stream.pipeThrough(
new TransformStream({
const stream = readableShape.stream.pipeThrough(
new TransformStream<unknown, z.output<TShapeSchema>>({
async transform(chunk, controller) {
const result = schema.safeParse(chunk);

Expand All @@ -46,6 +61,14 @@ export function zodShapeStream<TShapeSchema extends z.ZodTypeAny>(
},
})
);

return {
stream: stream as AsyncIterableStream<z.output<TShapeSchema>>,
stop: () => {
console.log("Stopping zodShapeStream with abortController.abort()");
abortController.abort();
},
};
}

export type AsyncIterableStream<T> = AsyncIterable<T> & ReadableStream<T>;
Expand Down Expand Up @@ -104,14 +127,19 @@ class ReadableShapeStream<T extends Row<unknown> = Row> {
readonly #currentState: Map<string, T> = new Map();
readonly #changeStream: AsyncIterableStream<T>;
#error: FetchError | false = false;
#unsubscribe?: () => void;

stop() {
this.#unsubscribe?.();
}

constructor(stream: ShapeStreamInterface<T>) {
this.#stream = stream;

// Create the source stream that will receive messages
const source = new ReadableStream<Message<T>[]>({
start: (controller) => {
this.#stream.subscribe(
this.#unsubscribe = this.#stream.subscribe(
(messages) => controller.enqueue(messages),
this.#handleError.bind(this)
);
Expand All @@ -121,41 +149,44 @@ class ReadableShapeStream<T extends Row<unknown> = Row> {
// Create the transformed stream that processes messages and emits complete rows
this.#changeStream = createAsyncIterableStream(source, {
transform: (messages, controller) => {
messages.forEach((message) => {
const updatedKeys = new Set<string>();

for (const message of messages) {
if (isChangeMessage(message)) {
const key = message.key;
switch (message.headers.operation) {
case "insert": {
this.#currentState.set(message.key, message.value);
controller.enqueue(message.value);
// New row entirely
this.#currentState.set(key, message.value);
updatedKeys.add(key);
break;
}
case "update": {
const existingRow = this.#currentState.get(message.key);
if (existingRow) {
const updatedRow = {
...existingRow,
...message.value,
};
this.#currentState.set(message.key, updatedRow);
controller.enqueue(updatedRow);
} else {
this.#currentState.set(message.key, message.value);
controller.enqueue(message.value);
}
// Merge updates into existing row if any, otherwise treat as new
const existingRow = this.#currentState.get(key);
const updatedRow = existingRow
? { ...existingRow, ...message.value }
: message.value;
this.#currentState.set(key, updatedRow);
updatedKeys.add(key);
break;
}
}
} else if (isControlMessage(message)) {
if (message.headers.control === "must-refetch") {
this.#currentState.clear();
this.#error = false;
}
}
}

if (isControlMessage(message)) {
switch (message.headers.control) {
case "must-refetch":
this.#currentState.clear();
this.#error = false;
break;
}
// Now enqueue only one updated row per key, after all messages have been processed.
for (const key of updatedKeys) {
const finalRow = this.#currentState.get(key);
if (finalRow) {
controller.enqueue(finalRow);
}
});
}
},
});
}
Expand Down
12 changes: 9 additions & 3 deletions packages/react-hooks/src/hooks/useRealtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,13 @@ export function useRealtimeRun<TTask extends AnyTask>(
}
}, [runId, mutateRun, abortControllerRef, apiClient, setError]);

const hasCalledOnCompleteRef = useRef(false);

// Effect to handle onComplete callback
useEffect(() => {
if (isComplete && options?.onComplete && run) {
if (isComplete && run && options?.onComplete && !hasCalledOnCompleteRef.current) {
options.onComplete(run, error);
hasCalledOnCompleteRef.current = true;
}
}, [isComplete, run, error, options?.onComplete]);

Expand Down Expand Up @@ -261,10 +264,13 @@ export function useRealtimeRunWithStreams<
}
}, [runId, mutateRun, mutateStreams, streamsRef, abortControllerRef, apiClient, setError]);

const hasCalledOnCompleteRef = useRef(false);

// Effect to handle onComplete callback
useEffect(() => {
if (isComplete && options?.onComplete && run) {
if (isComplete && run && options?.onComplete && !hasCalledOnCompleteRef.current) {
options.onComplete(run, error);
hasCalledOnCompleteRef.current = true;
}
}, [isComplete, run, error, options?.onComplete]);

Expand Down Expand Up @@ -593,7 +599,7 @@ async function processRealtimeRunWithStreams<
nextStreamData[type] = [...(existingDataRef.current[type] || []), ...chunks];
}

await mutateStreamData(nextStreamData);
mutateStreamData(nextStreamData);
}, throttleInMs);

for await (const part of subscription.withStreams<TStreams>()) {
Expand Down
Loading

0 comments on commit 7014057

Please sign in to comment.