Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,5 @@ next-env.d.ts

# Code Coverage
js/plugins/compat-oai/coverage/
js/plugins/firebase/database-debug.log
js/plugins/firebase/firestore-debug.log
13 changes: 12 additions & 1 deletion js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,25 @@ import { lazy } from './async.js';
import { getContext, runWithContext, type ActionContext } from './context.js';
import type { ActionType, Registry } from './registry.js';
import { parseSchema } from './schema.js';
import {
type ActionStreamInput,
type ActionStreamSubscriber,
type StreamManager,
} from './streaming.js';
import {
SPAN_TYPE_ATTR,
runInNewSpan,
setCustomMetadataAttributes,
} from './tracing.js';

export { StatusCodes, StatusSchema, type Status } from './statusTypes.js';
export type { JSONSchema7 };
export { InMemoryStreamManager, StreamNotFoundError } from './streaming.js';
export type {
ActionStreamInput,
ActionStreamSubscriber,
JSONSchema7,
StreamManager,
};

const makeNoopAbortSignal = () => new AbortController().signal;

Expand Down
27 changes: 26 additions & 1 deletion js/core/src/async.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export interface Task<T> {
}

/** Utility for creating Tasks. */
function createTask<T>(): Task<T> {
export function createTask<T>(): Task<T> {
let resolve: unknown, reject: unknown;
const promise = new Promise<T>(
(res, rej) => ([resolve, reject] = [res, rej])
Expand Down Expand Up @@ -126,3 +126,28 @@ export function lazy<T>(fn: () => T | PromiseLike<T>): PromiseLike<T> {
}
});
}

/**
* A queue for asynchronous tasks. The queue ensures that only one task runs at a time in order.
*/
export class AsyncTaskQueue {
private last: Promise<any> = Promise.resolve();

/**
* Adds a task to the queue.
* The task will be executed when its turn comes up in the queue.
* @param task A function that returns a value or a PromiseLike.
*/
enqueue(task: () => any | PromiseLike<any>) {
this.last = this.last.then(() => lazy(task)).then((res) => res);
// Prevent unhandled promise rejections.
this.last.catch(() => {});
}

/**
* Waits for all tasks currently in the queue to complete.
*/
async merge() {
await this.last;
}
}
155 changes: 155 additions & 0 deletions js/core/src/streaming.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
/**
* Copyright 2024 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import { GenkitError } from './error';

export class StreamNotFoundError extends GenkitError {
constructor(message: string) {
super({ status: 'NOT_FOUND', message });
this.name = 'StreamNotFoundError';
}
}

export interface ActionStreamInput<S, O> {
write(chunk: S): Promise<void>;
done(output: O): Promise<void>;
error(err: any): Promise<void>;
}

export type ActionStreamSubscriber<S, O> = {
onChunk: (chunk: S) => void;
onDone: (output: O) => void;
onError: (error: any) => void;
};

export interface StreamManager {
open<S, O>(streamId: string): Promise<ActionStreamInput<S, O>>;
subscribe<S, O>(
streamId: string,
options: ActionStreamSubscriber<S, O>
): Promise<{ unsubscribe: () => void }>;
}

type StreamState<S, O> =
| {
status: 'open';
chunks: S[];
subscribers: ActionStreamSubscriber<S, O>[];
lastTouched: number;
}
| { status: 'done'; chunks: S[]; output: O; lastTouched: number }
| { status: 'error'; chunks: S[]; error: any; lastTouched: number };

export class InMemoryStreamManager implements StreamManager {
private streams: Map<string, StreamState<any, any>> = new Map();

constructor(private options: { ttlSeconds?: number } = {}) {}

private _cleanup() {
const ttl = (this.options.ttlSeconds ?? 5 * 60) * 1000;
const now = Date.now();
for (const [streamId, stream] of this.streams.entries()) {
if (stream.status !== 'open' && now - stream.lastTouched > ttl) {
this.streams.delete(streamId);
}
}
}

async open<S, O>(streamId: string): Promise<ActionStreamInput<S, O>> {
this._cleanup();
if (this.streams.has(streamId)) {
throw new Error(`Stream with id ${streamId} already exists.`);
}
this.streams.set(streamId, {
status: 'open',
chunks: [],
subscribers: [],
lastTouched: Date.now(),
});

return {
write: async (chunk: S) => {
const stream = this.streams.get(streamId);
if (stream?.status === 'open') {
stream.chunks.push(chunk);
stream.subscribers.forEach((s) => s.onChunk(chunk));
stream.lastTouched = Date.now();
}
},
done: async (output: O) => {
const stream = this.streams.get(streamId);
if (stream?.status === 'open') {
this.streams.set(streamId, {
status: 'done',
chunks: stream.chunks,
output,
lastTouched: Date.now(),
});
stream.subscribers.forEach((s) => s.onDone(output));
}
},
error: async (err: any) => {
const stream = this.streams.get(streamId);
if (stream?.status === 'open') {
stream.subscribers.forEach((s) => s.onError(err));
this.streams.set(streamId, {
status: 'error',
chunks: stream.chunks,
error: err,
lastTouched: Date.now(),
});
}
},
};
}

async subscribe<S, O>(
streamId: string,
subscriber: ActionStreamSubscriber<S, O>
): Promise<{ unsubscribe: () => void }> {
const stream = this.streams.get(streamId);
if (!stream) {
throw new StreamNotFoundError(`Stream with id ${streamId} not found.`);
}

if (stream.status === 'done') {
for (const chunk of stream.chunks) {
subscriber.onChunk(chunk);
}
subscriber.onDone(stream.output);
} else if (stream.status === 'error') {
for (const chunk of stream.chunks) {
subscriber.onChunk(chunk);
}
subscriber.onError(stream.error);
} else {
stream.chunks.forEach((chunk) => subscriber.onChunk(chunk));
stream.subscribers.push(subscriber);
}

return {
unsubscribe: () => {
const currentStream = this.streams.get(streamId);
if (currentStream?.status === 'open') {
const index = currentStream.subscribers.indexOf(subscriber);
if (index > -1) {
currentStream.subscribers.splice(index, 1);
}
}
},
};
}
}
76 changes: 75 additions & 1 deletion js/core/tests/async_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,81 @@

import * as assert from 'assert';
import { describe, it } from 'node:test';
import { LazyPromise } from '../src/async';
import { AsyncTaskQueue, LazyPromise } from '../src/async';
import { sleep } from './utils';

describe('AsyncTaskQueue', () => {
it('should execute tasks in order', async () => {
const queue = new AsyncTaskQueue();
const results: number[] = [];

queue.enqueue(async () => {
await sleep(10);
results.push(1);
});
queue.enqueue(() => {
results.push(2);
});

await queue.merge();

assert.deepStrictEqual(results, [1, 2]);
});

it('should handle empty queue', async () => {
const queue = new AsyncTaskQueue();
await queue.merge();
// No error should be thrown.
});

it('should handle tasks added after merge is called', async () => {
const queue = new AsyncTaskQueue();
const results: number[] = [];

queue.enqueue(async () => {
await sleep(10);
results.push(1);
});

queue.enqueue(() => {
results.push(2);
});

assert.deepStrictEqual(results, []);

await queue.merge();

assert.deepStrictEqual(results, [1, 2]);
});

it('should propagate errors', async () => {
const queue = new AsyncTaskQueue();
const error = new Error('test error');

queue.enqueue(() => {
throw error;
});

await assert.rejects(queue.merge(), error);
});

it('should execute tasks without calling merge', async () => {
const queue = new AsyncTaskQueue();
const results: number[] = [];

queue.enqueue(async () => {
await sleep(20);
results.push(1);
});
queue.enqueue(() => {
results.push(2);
});

await sleep(30);

assert.deepStrictEqual(results, [1, 2]);
});
});

describe('LazyPromise', () => {
it('call its function lazily', async () => {
Expand Down
4 changes: 4 additions & 0 deletions js/core/tests/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,7 @@ export class TestSpanExporter implements SpanExporter {
return Promise.resolve();
}
}

export function sleep(ms: number) {
return new Promise((resolve) => setTimeout(resolve, ms));
}
8 changes: 8 additions & 0 deletions js/genkit/src/beta.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,13 @@
* limitations under the License.
*/

export {
InMemoryStreamManager,
StreamNotFoundError,
type ActionStreamInput,
type ActionStreamSubscriber,
type StreamManager,
} from '@genkit-ai/core';
export { AsyncTaskQueue, lazy } from '@genkit-ai/core/async';
export * from './common.js';
export { GenkitBeta, genkit, type GenkitBetaOptions } from './genkit-beta.js';
Loading